diff --git a/docs/changelog/122134.yaml b/docs/changelog/122134.yaml new file mode 100644 index 0000000000000..25ca556789525 --- /dev/null +++ b/docs/changelog/122134.yaml @@ -0,0 +1,5 @@ +pr: 122134 +summary: Adding integration for VoyageAI embeddings and rerank models +area: Machine Learning +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 03c55ecaac40b..809d44ce13d98 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -180,6 +180,7 @@ static TransportVersion def(int id) { public static final TransportVersion REMOVE_ALL_APPLICABLE_SELECTOR_BACKPORT_8_19 = def(8_841_0_02); public static final TransportVersion ESQL_RETRY_ON_SHARD_LEVEL_FAILURE_BACKPORT_8_19 = def(8_841_0_03); public static final TransportVersion ESQL_SUPPORT_PARTIAL_RESULTS_BACKPORT_8_19 = def(8_841_0_04); + public static final TransportVersion VOYAGE_AI_INTEGRATION_ADDED_BACKPORT_8_X = def(8_841_0_05); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00); public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01); public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02); @@ -199,7 +200,7 @@ static TransportVersion def(int id) { public static final TransportVersion ESQL_SUPPORT_PARTIAL_RESULTS = def(9_011_0_00); public static final TransportVersion REMOVE_REPOSITORY_CONFLICT_MESSAGE = def(9_012_0_00); public static final TransportVersion RERANKER_FAILURES_ALLOWED = def(9_013_0_00); - + public static final TransportVersion VOYAGE_AI_INTEGRATION_ADDED = def(9_014_0_00); /* * STOP! READ THIS FIRST! No, really, * ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _ diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index 9d4cec798964a..859a065b6e1a0 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -25,7 +25,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest { @SuppressWarnings("unchecked") public void testGetServicesWithoutTaskType() throws IOException { List services = getAllServices(); - assertThat(services.size(), equalTo(19)); + assertThat(services.size(), equalTo(20)); String[] providers = new String[services.size()]; for (int i = 0; i < services.size(); i++) { @@ -53,6 +53,7 @@ public void testGetServicesWithoutTaskType() throws IOException { "test_reranking_service", "test_service", "text_embedding_test_service", + "voyageai", "watsonxai" ).toArray(), providers @@ -62,7 +63,7 @@ public void testGetServicesWithoutTaskType() throws IOException { @SuppressWarnings("unchecked") public void testGetServicesWithTextEmbeddingTaskType() throws IOException { List services = getServices(TaskType.TEXT_EMBEDDING); - assertThat(services.size(), equalTo(14)); + assertThat(services.size(), equalTo(15)); String[] providers = new String[services.size()]; for (int i = 0; i < services.size(); i++) { @@ -85,6 +86,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException { "mistral", "openai", "text_embedding_test_service", + "voyageai", "watsonxai" ).toArray(), providers @@ -94,7 +96,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException { @SuppressWarnings("unchecked") public void testGetServicesWithRerankTaskType() throws IOException { List services = getServices(TaskType.RERANK); - assertThat(services.size(), equalTo(6)); + assertThat(services.size(), equalTo(7)); String[] providers = new String[services.size()]; for (int i = 0; i < services.size(); i++) { @@ -103,7 +105,8 @@ public void testGetServicesWithRerankTaskType() throws IOException { } assertArrayEquals( - List.of("alibabacloud-ai-search", "cohere", "elasticsearch", "googlevertexai", "jinaai", "test_reranking_service").toArray(), + List.of("alibabacloud-ai-search", "cohere", "elasticsearch", "googlevertexai", "jinaai", "test_reranking_service", "voyageai") + .toArray(), providers ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 563247de44f81..ef79452e94c74 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -90,6 +90,11 @@ import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings; +import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankServiceSettings; +import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankTaskSettings; import java.util.ArrayList; import java.util.List; @@ -142,6 +147,7 @@ public static List getNamedWriteables() { addEisNamedWriteables(namedWriteables); addAlibabaCloudSearchNamedWriteables(namedWriteables); addJinaAINamedWriteables(namedWriteables); + addVoyageAINamedWriteables(namedWriteables); addUnifiedNamedWriteables(namedWriteables); @@ -626,6 +632,28 @@ private static void addJinaAINamedWriteables(List ); } + private static void addVoyageAINamedWriteables(List namedWriteables) { + namedWriteables.add( + new NamedWriteableRegistry.Entry(ServiceSettings.class, VoyageAIServiceSettings.NAME, VoyageAIServiceSettings::new) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + VoyageAIEmbeddingsServiceSettings.NAME, + VoyageAIEmbeddingsServiceSettings::new + ) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry(TaskSettings.class, VoyageAIEmbeddingsTaskSettings.NAME, VoyageAIEmbeddingsTaskSettings::new) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry(ServiceSettings.class, VoyageAIRerankServiceSettings.NAME, VoyageAIRerankServiceSettings::new) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry(TaskSettings.class, VoyageAIRerankTaskSettings.NAME, VoyageAIRerankTaskSettings::new) + ); + } + private static void addEisNamedWriteables(List namedWriteables) { namedWriteables.add( new NamedWriteableRegistry.Entry( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index d865b241bb4e0..0b01ad5e3c66f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -128,6 +128,7 @@ import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService; import org.elasticsearch.xpack.inference.services.mistral.MistralService; import org.elasticsearch.xpack.inference.services.openai.OpenAiService; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIService; import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import java.util.ArrayList; @@ -359,6 +360,7 @@ public List getInferenceServiceFactories() { context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get()), context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()), context -> new JinaAIService(httpFactory.get(), serviceComponents.get()), + context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()), ElasticsearchInternalService::new ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreator.java new file mode 100644 index 0000000000000..6a4a9e5f93639 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreator.java @@ -0,0 +1,52 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.voyageai; + +import org.elasticsearch.inference.InputType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.VoyageAIEmbeddingsRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.VoyageAIRerankRequestManager; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel; + +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; + +/** + * Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the voyageai model type. + */ +public class VoyageAIActionCreator implements VoyageAIActionVisitor { + private final Sender sender; + private final ServiceComponents serviceComponents; + + public VoyageAIActionCreator(Sender sender, ServiceComponents serviceComponents) { + this.sender = Objects.requireNonNull(sender); + this.serviceComponents = Objects.requireNonNull(serviceComponents); + } + + @Override + public ExecutableAction create(VoyageAIEmbeddingsModel model, Map taskSettings, InputType inputType) { + var overriddenModel = VoyageAIEmbeddingsModel.of(model, taskSettings, inputType); + var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(overriddenModel.uri(), "VoyageAI embeddings"); + var requestCreator = VoyageAIEmbeddingsRequestManager.of(overriddenModel, serviceComponents.threadPool()); + return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage); + } + + @Override + public ExecutableAction create(VoyageAIRerankModel model, Map taskSettings) { + var overriddenModel = VoyageAIRerankModel.of(model, taskSettings); + var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(overriddenModel.uri(), "VoyageAI rerank"); + var requestCreator = VoyageAIRerankRequestManager.of(overriddenModel, serviceComponents.threadPool()); + return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionVisitor.java new file mode 100644 index 0000000000000..d6732dba95475 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionVisitor.java @@ -0,0 +1,21 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.voyageai; + +import org.elasticsearch.inference.InputType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel; + +import java.util.Map; + +public interface VoyageAIActionVisitor { + ExecutableAction create(VoyageAIEmbeddingsModel model, Map taskSettings, InputType inputType); + + ExecutableAction create(VoyageAIRerankModel model, Map taskSettings); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/VoyageAIEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/VoyageAIEmbeddingsRequestManager.java new file mode 100644 index 0000000000000..be186dc54f8f2 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/VoyageAIEmbeddingsRequestManager.java @@ -0,0 +1,57 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIEmbeddingsRequest; +import org.elasticsearch.xpack.inference.external.response.voyageai.VoyageAIEmbeddingsResponseEntity; +import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIResponseHandler; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel; + +import java.util.List; +import java.util.Objects; +import java.util.function.Supplier; + +public class VoyageAIEmbeddingsRequestManager extends VoyageAIRequestManager { + private static final Logger logger = LogManager.getLogger(VoyageAIEmbeddingsRequestManager.class); + private static final ResponseHandler HANDLER = createEmbeddingsHandler(); + + private static ResponseHandler createEmbeddingsHandler() { + return new VoyageAIResponseHandler("voyageai text embedding", VoyageAIEmbeddingsResponseEntity::fromResponse); + } + + public static VoyageAIEmbeddingsRequestManager of(VoyageAIEmbeddingsModel model, ThreadPool threadPool) { + return new VoyageAIEmbeddingsRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool)); + } + + private final VoyageAIEmbeddingsModel model; + + private VoyageAIEmbeddingsRequestManager(VoyageAIEmbeddingsModel model, ThreadPool threadPool) { + super(threadPool, model); + this.model = Objects.requireNonNull(model); + } + + @Override + public void execute( + InferenceInputs inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest(docsInput, model); + + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/VoyageAIRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/VoyageAIRequestManager.java new file mode 100644 index 0000000000000..99a0617ff510a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/VoyageAIRequestManager.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIModel; + +import java.util.Map; +import java.util.Objects; + +abstract class VoyageAIRequestManager extends BaseRequestManager { + private static final String DEFAULT_MODEL_FAMILY = "default_model_family"; + private static final Map MODEL_TO_MODEL_FAMILY = Map.of( + "voyage-multimodal-3", + "embed_multimodal", + "voyage-3-large", + "embed_large", + "voyage-code-3", + "embed_large", + "voyage-3", + "embed_medium", + "voyage-3-lite", + "embed_small", + "voyage-finance-2", + "embed_large", + "voyage-law-2", + "embed_large", + "voyage-code-2", + "embed_large", + "rerank-2", + "rerank_large", + "rerank-2-lite", + "rerank_small" + ); + + protected VoyageAIRequestManager(ThreadPool threadPool, VoyageAIModel model) { + super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings()); + } + + record RateLimitGrouping(int apiKeyHash) { + public static RateLimitGrouping of(VoyageAIModel model) { + Objects.requireNonNull(model); + String modelId = model.getServiceSettings().modelId(); + String modelFamily = MODEL_TO_MODEL_FAMILY.getOrDefault(modelId, DEFAULT_MODEL_FAMILY); + + return new RateLimitGrouping(modelFamily.hashCode()); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/VoyageAIRerankRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/VoyageAIRerankRequestManager.java new file mode 100644 index 0000000000000..ca91ad4dda276 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/VoyageAIRerankRequestManager.java @@ -0,0 +1,56 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIRerankRequest; +import org.elasticsearch.xpack.inference.external.response.voyageai.VoyageAIRerankResponseEntity; +import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIResponseHandler; +import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel; + +import java.util.Objects; +import java.util.function.Supplier; + +public class VoyageAIRerankRequestManager extends VoyageAIRequestManager { + private static final Logger logger = LogManager.getLogger(VoyageAIRerankRequestManager.class); + private static final ResponseHandler HANDLER = createVoyageAIResponseHandler(); + + private static ResponseHandler createVoyageAIResponseHandler() { + return new VoyageAIResponseHandler("voyageai rerank", (request, response) -> VoyageAIRerankResponseEntity.fromResponse(response)); + } + + public static VoyageAIRerankRequestManager of(VoyageAIRerankModel model, ThreadPool threadPool) { + return new VoyageAIRerankRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool)); + } + + private final VoyageAIRerankModel model; + + private VoyageAIRerankRequestManager(VoyageAIRerankModel model, ThreadPool threadPool) { + super(threadPool, model); + this.model = model; + } + + @Override + public void execute( + InferenceInputs inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + var rerankInput = QueryAndDocsInputs.of(inferenceInputs); + VoyageAIRerankRequest request = new VoyageAIRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model); + + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequest.java new file mode 100644 index 0000000000000..a24f5dc8e14ea --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequest.java @@ -0,0 +1,87 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.voyageai; + +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIAccount; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Objects; + +public class VoyageAIEmbeddingsRequest extends VoyageAIRequest { + + private final VoyageAIAccount account; + private final List input; + private final VoyageAIEmbeddingsServiceSettings serviceSettings; + private final VoyageAIEmbeddingsTaskSettings taskSettings; + private final String model; + private final String inferenceEntityId; + + public VoyageAIEmbeddingsRequest(List input, VoyageAIEmbeddingsModel embeddingsModel) { + Objects.requireNonNull(embeddingsModel); + + account = VoyageAIAccount.of(embeddingsModel); + this.input = Objects.requireNonNull(input); + serviceSettings = embeddingsModel.getServiceSettings(); + taskSettings = embeddingsModel.getTaskSettings(); + model = embeddingsModel.getServiceSettings().getCommonSettings().modelId(); + inferenceEntityId = embeddingsModel.getInferenceEntityId(); + } + + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(account.uri()); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString(new VoyageAIEmbeddingsRequestEntity(input, serviceSettings, taskSettings, model)) + .getBytes(StandardCharsets.UTF_8) + ); + httpPost.setEntity(byteEntity); + + decorateWithHeaders(httpPost, account); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public String getInferenceEntityId() { + return inferenceEntityId; + } + + @Override + public URI getURI() { + return account.uri(); + } + + @Override + public Request truncate() { + return this; + } + + @Override + public boolean[] getTruncationInfo() { + return null; + } + + public VoyageAIEmbeddingsTaskSettings getTaskSettings() { + return taskSettings; + } + + public VoyageAIEmbeddingsServiceSettings getServiceSettings() { + return serviceSettings; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntity.java new file mode 100644 index 0000000000000..8191443edf75c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntity.java @@ -0,0 +1,83 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.voyageai; + +import org.elasticsearch.inference.InputType; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings.invalidInputTypeMessage; + +public record VoyageAIEmbeddingsRequestEntity( + List input, + VoyageAIEmbeddingsServiceSettings serviceSettings, + VoyageAIEmbeddingsTaskSettings taskSettings, + String model +) implements ToXContentObject { + + private static final String DOCUMENT = "document"; + private static final String QUERY = "query"; + private static final String INPUT_FIELD = "input"; + private static final String MODEL_FIELD = "model"; + public static final String INPUT_TYPE_FIELD = "input_type"; + public static final String TRUNCATION_FIELD = "truncation"; + public static final String OUTPUT_DIMENSION = "output_dimension"; + static final String OUTPUT_DTYPE_FIELD = "output_dtype"; + + public VoyageAIEmbeddingsRequestEntity { + Objects.requireNonNull(input); + Objects.requireNonNull(model); + Objects.requireNonNull(taskSettings); + Objects.requireNonNull(serviceSettings); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(INPUT_FIELD, input); + builder.field(MODEL_FIELD, model); + + var inputType = convertToString(taskSettings.getInputType()); + if (inputType != null) { + builder.field(INPUT_TYPE_FIELD, inputType); + } + + if (taskSettings.getTruncation() != null) { + builder.field(TRUNCATION_FIELD, taskSettings.getTruncation()); + } + + if (serviceSettings.dimensions() != null) { + builder.field(OUTPUT_DIMENSION, serviceSettings.dimensions()); + } + + if (serviceSettings.getEmbeddingType() != null) { + builder.field(OUTPUT_DTYPE_FIELD, serviceSettings.getEmbeddingType().toRequestString()); + } + + builder.endObject(); + return builder; + } + + static String convertToString(InputType inputType) { + return switch (inputType) { + case null -> null; + case INGEST -> DOCUMENT; + case SEARCH -> QUERY; + default -> { + assert false : invalidInputTypeMessage(inputType); + yield null; + } + }; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRequest.java new file mode 100644 index 0000000000000..5455a2f0301f2 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRequest.java @@ -0,0 +1,26 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.voyageai; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIAccount; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; + +public abstract class VoyageAIRequest implements Request { + + public static void decorateWithHeaders(HttpPost request, VoyageAIAccount account) { + request.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); + request.setHeader(createAuthBearerHeader(account.apiKey())); + request.setHeader(VoyageAIUtils.createRequestSourceHeader()); + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequest.java new file mode 100644 index 0000000000000..37d15fe1fe2c5 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequest.java @@ -0,0 +1,77 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.voyageai; + +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIAccount; +import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel; +import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankTaskSettings; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Objects; + +public class VoyageAIRerankRequest extends VoyageAIRequest { + + private final VoyageAIAccount account; + private final String query; + private final List input; + private final VoyageAIRerankTaskSettings taskSettings; + private final String model; + private final String inferenceEntityId; + + public VoyageAIRerankRequest(String query, List input, VoyageAIRerankModel model) { + Objects.requireNonNull(model); + + this.account = VoyageAIAccount.of(model); + this.input = Objects.requireNonNull(input); + this.query = Objects.requireNonNull(query); + taskSettings = model.getTaskSettings(); + this.model = model.getServiceSettings().modelId(); + inferenceEntityId = model.getInferenceEntityId(); + } + + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(account.uri()); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString(new VoyageAIRerankRequestEntity(query, input, taskSettings, model)).getBytes(StandardCharsets.UTF_8) + ); + httpPost.setEntity(byteEntity); + + decorateWithHeaders(httpPost, account); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public String getInferenceEntityId() { + return inferenceEntityId; + } + + @Override + public URI getURI() { + return account.uri(); + } + + @Override + public Request truncate() { + return this; + } + + @Override + public boolean[] getTruncationInfo() { + return null; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntity.java new file mode 100644 index 0000000000000..0f7baaa35044e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntity.java @@ -0,0 +1,63 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.voyageai; + +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankTaskSettings; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public record VoyageAIRerankRequestEntity(String model, String query, List documents, VoyageAIRerankTaskSettings taskSettings) + implements + ToXContentObject { + + private static final String DOCUMENTS_FIELD = "documents"; + private static final String QUERY_FIELD = "query"; + private static final String MODEL_FIELD = "model"; + public static final String TRUNCATION_FIELD = "truncation"; + public static final String RETURN_DOCUMENTS_FIELD = "return_documents"; + + public VoyageAIRerankRequestEntity { + Objects.requireNonNull(query); + Objects.requireNonNull(documents); + Objects.requireNonNull(model); + Objects.requireNonNull(taskSettings); + } + + public VoyageAIRerankRequestEntity(String query, List input, VoyageAIRerankTaskSettings taskSettings, String model) { + this(model, query, input, taskSettings != null ? taskSettings : VoyageAIRerankTaskSettings.EMPTY_SETTINGS); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + builder.field(MODEL_FIELD, model); + builder.field(QUERY_FIELD, query); + builder.field(DOCUMENTS_FIELD, documents); + + if (taskSettings.getDoesReturnDocuments() != null) { + builder.field(VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, taskSettings.getDoesReturnDocuments()); + } + + if (taskSettings.getTopKDocumentsOnly() != null) { + builder.field(VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY, taskSettings.getTopKDocumentsOnly()); + } + + if (taskSettings.getTruncation() != null) { + builder.field(TRUNCATION_FIELD, taskSettings.getTruncation()); + } + + builder.endObject(); + return builder; + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIUtils.java new file mode 100644 index 0000000000000..130093826ee0f --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIUtils.java @@ -0,0 +1,26 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.voyageai; + +import org.apache.http.Header; +import org.apache.http.message.BasicHeader; + +public class VoyageAIUtils { + public static final String HOST = "api.voyageai.com"; + public static final String VERSION_1 = "v1"; + public static final String EMBEDDINGS_PATH = "embeddings"; + public static final String RERANK_PATH = "rerank"; + public static final String REQUEST_SOURCE_HEADER = "Request-Source"; + public static final String ELASTIC_REQUEST_SOURCE = "unspecified:elasticsearch"; + + public static Header createRequestSourceHeader() { + return new BasicHeader(REQUEST_SOURCE_HEADER, ELASTIC_REQUEST_SOURCE); + } + + private VoyageAIUtils() {} +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntity.java new file mode 100644 index 0000000000000..f2c6cb5db4c32 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntity.java @@ -0,0 +1,197 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + * + * this file was contributed to by a generative AI + */ + +package org.elasticsearch.xpack.inference.external.response.voyageai; + +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding; +import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIEmbeddingsRequest; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingType; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingType.toLowerCase; + +public class VoyageAIEmbeddingsResponseEntity { + private static final String VALID_EMBEDDING_TYPES_STRING = supportedEmbeddingTypes(); + + private static String supportedEmbeddingTypes() { + String[] validTypes = new String[] { + toLowerCase(VoyageAIEmbeddingType.FLOAT), + toLowerCase(VoyageAIEmbeddingType.INT8), + toLowerCase(VoyageAIEmbeddingType.BIT) }; + Arrays.sort(validTypes); + return String.join(", ", validTypes); + } + + record EmbeddingInt8Result(List entries) { + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + EmbeddingInt8Result.class.getSimpleName(), + true, + args -> new EmbeddingInt8Result((List) args[0]) + ); + + static { + PARSER.declareObjectArray(constructorArg(), EmbeddingInt8ResultEntry.PARSER::apply, new ParseField("data")); + } + } + + record EmbeddingInt8ResultEntry(Integer index, List embedding) { + + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + EmbeddingInt8ResultEntry.class.getSimpleName(), + true, + args -> new EmbeddingInt8ResultEntry((Integer) args[0], (List) args[1]) + ); + + static { + PARSER.declareInt(constructorArg(), new ParseField("index")); + PARSER.declareIntArray(constructorArg(), new ParseField("embedding")); + } + + private static void checkByteBounds(Integer value) { + if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) { + throw new IllegalArgumentException("Value [" + value + "] is out of range for a byte"); + } + } + + public InferenceByteEmbedding toInferenceByteEmbedding() { + embedding.forEach(EmbeddingInt8ResultEntry::checkByteBounds); + return InferenceByteEmbedding.of(embedding.stream().map(Integer::byteValue).toList()); + } + } + + record EmbeddingFloatResult(List entries) { + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + EmbeddingFloatResult.class.getSimpleName(), + true, + args -> new EmbeddingFloatResult((List) args[0]) + ); + + static { + PARSER.declareObjectArray(constructorArg(), EmbeddingFloatResultEntry.PARSER::apply, new ParseField("data")); + } + } + + record EmbeddingFloatResultEntry(Integer index, List embedding) { + + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + EmbeddingFloatResultEntry.class.getSimpleName(), + true, + args -> new EmbeddingFloatResultEntry((Integer) args[0], (List) args[1]) + ); + + static { + PARSER.declareInt(constructorArg(), new ParseField("index")); + PARSER.declareFloatArray(constructorArg(), new ParseField("embedding")); + } + + public InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding toInferenceFloatEmbedding() { + return InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding.of(embedding); + } + } + + /** + * Parses the VoyageAI json response. + * For a request like: + * + *
+     *     
+     *        {
+     *          "input": [
+     *            "Sample text 1",
+     *            "Sample text 2"
+     *          ],
+     *          "model": "voyage-3-large"
+     *        }
+     *     
+     * 
+ * + * The response would look like: + * + *
+     * 
+     * {
+     *  "object": "list",
+     *  "data": [
+     *      {
+     *          "object": "embedding",
+     *          "embedding": [
+     *              -0.009327292,
+     *              -0.0028842222,
+     *          ],
+     *          "index": 0
+     *      },
+     *      {
+     *          "object": "embedding",
+     *          "embedding": [ ... ],
+     *          "index": 1
+     *      }
+     *  ],
+     *  "model": "voyage-3-large",
+     *  "usage": {
+     *      "total_tokens": 10
+     *  }
+     * }
+     * 
+     * 
+ */ + public static InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException { + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + VoyageAIEmbeddingType embeddingType = ((VoyageAIEmbeddingsRequest) request).getServiceSettings().getEmbeddingType(); + + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { + if (embeddingType == null || embeddingType == VoyageAIEmbeddingType.FLOAT) { + var embeddingResult = EmbeddingFloatResult.PARSER.apply(jsonParser, null); + + List embeddingList = embeddingResult.entries.stream() + .map(EmbeddingFloatResultEntry::toInferenceFloatEmbedding) + .toList(); + return new InferenceTextEmbeddingFloatResults(embeddingList); + } else if (embeddingType == VoyageAIEmbeddingType.INT8) { + var embeddingResult = EmbeddingInt8Result.PARSER.apply(jsonParser, null); + List embeddingList = embeddingResult.entries.stream() + .map(EmbeddingInt8ResultEntry::toInferenceByteEmbedding) + .toList(); + return new InferenceTextEmbeddingByteResults(embeddingList); + } else if (embeddingType == VoyageAIEmbeddingType.BIT || embeddingType == VoyageAIEmbeddingType.BINARY) { + var embeddingResult = EmbeddingInt8Result.PARSER.apply(jsonParser, null); + List embeddingList = embeddingResult.entries.stream() + .map(EmbeddingInt8ResultEntry::toInferenceByteEmbedding) + .toList(); + return new InferenceTextEmbeddingBitResults(embeddingList); + } else { + throw new IllegalArgumentException( + "Illegal embedding_type value: " + embeddingType + ". Supported types are: " + VALID_EMBEDDING_TYPES_STRING + ); + } + } + } + + private VoyageAIEmbeddingsResponseEntity() {} +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIErrorResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIErrorResponseEntity.java new file mode 100644 index 0000000000000..41ffaa61fcd26 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIErrorResponseEntity.java @@ -0,0 +1,46 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.voyageai; + +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; + +public class VoyageAIErrorResponseEntity extends ErrorResponse { + + private VoyageAIErrorResponseEntity(String errorMessage) { + super(errorMessage); + } + + /** + * Parse an HTTP response into a VoyageAIErrorResponseEntity + * + * @param response The error response + * @return An error entity if the response is JSON with a `detail` field containing the error message + * or null if the response does not contain the message field + */ + public static ErrorResponse fromResponse(HttpResult response) { + try ( + XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON) + .createParser(XContentParserConfiguration.EMPTY, response.body()) + ) { + var responseMap = jsonParser.map(); + var message = (String) responseMap.get("detail"); + if (message != null) { + return new VoyageAIErrorResponseEntity(message); + } + } catch (Exception e) { + // swallow the error + } + + return ErrorResponse.UNDEFINED_ERROR; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIRerankResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIRerankResponseEntity.java new file mode 100644 index 0000000000000..5438ba3644753 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIRerankResponseEntity.java @@ -0,0 +1,112 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + * + * this file was contributed to by a generative AI + */ + +package org.elasticsearch.xpack.inference.external.response.voyageai; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public class VoyageAIRerankResponseEntity { + + private static final Logger logger = LogManager.getLogger(VoyageAIRerankResponseEntity.class); + + record RerankResult(List entries) { + + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + RerankResult.class.getSimpleName(), + true, + args -> new RerankResult((List) args[0]) + ); + + static { + PARSER.declareObjectArray(constructorArg(), RerankResultEntry.PARSER::apply, new ParseField("data")); + } + } + + record RerankResultEntry(Float relevanceScore, Integer index, @Nullable String document) { + + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + RerankResultEntry.class.getSimpleName(), + args -> new RerankResultEntry((Float) args[0], (Integer) args[1], (String) args[2]) + ); + + static { + PARSER.declareFloat(constructorArg(), new ParseField("relevance_score")); + PARSER.declareInt(constructorArg(), new ParseField("index")); + PARSER.declareString(optionalConstructorArg(), new ParseField("document")); + } + + public RankedDocsResults.RankedDoc toRankedDoc() { + return new RankedDocsResults.RankedDoc(index, relevanceScore, document); + } + } + + /** + * Parses the VoyageAI ranked response. + * For a request like: + * "model": "rerank-2", + * "query": "What is the capital of the United States?", + * "top_k": 2, + * "documents": ["Carson City is the capital city of the American state of Nevada.", + * "The Commonwealth of the Northern Mariana ... Its capital is Saipan.", + * "Washington, D.C. (also known as simply Washington or D.C., ... It is a federal district.", + * "Capital punishment (the death penalty) ... As of 2017, capital punishment is legal in 30 of the 50 states."] + *

+ * The response will look like (without whitespace): + * { + * "object": "list", + * "data": [ + * { + * "relevance_score": 0.4375, + * "index": 0 + * }, + * { + * "relevance_score": 0.421875, + * "index": 1 + * } + * ], + * "model": "rerank-2", + * "usage": { + * "total_tokens": 26 + * } + * } + * @param response the http response from VoyageAI + * @return the parsed response + * @throws IOException if there is an error parsing the response + */ + public static InferenceServiceResults fromResponse(HttpResult response) throws IOException { + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { + var rerankResult = RerankResult.PARSER.apply(jsonParser, null); + + return new RankedDocsResults(rerankResult.entries.stream().map(RerankResultEntry::toRankedDoc).toList()); + } + } + + private VoyageAIRerankResponseEntity() {} +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/voyageai/VoyageAIAccount.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/voyageai/VoyageAIAccount.java new file mode 100644 index 0000000000000..bb7de2fc8ad0c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/voyageai/VoyageAIAccount.java @@ -0,0 +1,35 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.voyageai; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIModel; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Objects; + +public record VoyageAIAccount(URI uri, SecureString apiKey) { + + public static VoyageAIAccount of(VoyageAIModel model) { + try { + var uri = model.buildUri(); + return new VoyageAIAccount(uri, model.apiKey()); + } catch (URISyntaxException e) { + // using bad request here so that potentially sensitive URL information does not get logged + throw new ElasticsearchStatusException("Failed to construct VoyageAI URL", RestStatus.BAD_REQUEST, e); + } + } + + public VoyageAIAccount { + Objects.requireNonNull(uri); + Objects.requireNonNull(apiKey); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/voyageai/VoyageAIResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/voyageai/VoyageAIResponseHandler.java new file mode 100644 index 0000000000000..1611426034a52 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/voyageai/VoyageAIResponseHandler.java @@ -0,0 +1,62 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.voyageai; + +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.response.voyageai.VoyageAIErrorResponseEntity; + +/** + * Defines how to handle various errors returned from the VoyageAI integration. + * + */ +public class VoyageAIResponseHandler extends BaseResponseHandler { + static final String VALIDATION_ERROR_MESSAGE = "Received an input validation error response"; + static final String PAYMENT_ERROR_MESSAGE = "Payment required"; + + public VoyageAIResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, VoyageAIErrorResponseEntity::fromResponse); + } + + /** + * Validates the status code throws an RetryException if not in the range [200, 300). + * + * @param request The http request + * @param result The http response and body + * @throws RetryException Throws if status code is {@code >= 300 or < 200 } + */ + @Override + protected void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException { + if (result.isSuccessfulResponse()) { + return; + } + + // handle error codes + int statusCode = result.response().getStatusLine().getStatusCode(); + if (statusCode == 500) { + throw new RetryException(true, buildError(SERVER_ERROR, request, result)); + } else if (statusCode > 500) { + throw new RetryException(false, buildError(SERVER_ERROR, request, result)); + } else if (statusCode == 429) { + throw new RetryException(true, buildError(RATE_LIMIT, request, result)); + } else if (statusCode == 400 || statusCode == 422) { + throw new RetryException(false, buildError(VALIDATION_ERROR_MESSAGE, request, result)); + } else if (statusCode == 401) { + throw new RetryException(false, buildError(AUTHENTICATION, request, result)); + } else if (statusCode == 402) { + throw new RetryException(false, buildError(PAYMENT_ERROR_MESSAGE, request, result)); + } else if (statusCode >= 300 && statusCode < 400) { + throw new RetryException(false, buildError(REDIRECTION, request, result)); + } else { + throw new RetryException(false, buildError(UNSUCCESSFUL, request, result)); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIModel.java new file mode 100644 index 0000000000000..e63a716b96617 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIModel.java @@ -0,0 +1,94 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.voyageai.VoyageAIActionVisitor; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Map; +import java.util.Objects; + +public abstract class VoyageAIModel extends Model { + private final SecureString apiKey; + private final VoyageAIRateLimitServiceSettings rateLimitServiceSettings; + protected final URI uri; + + public VoyageAIModel( + ModelConfigurations configurations, + ModelSecrets secrets, + @Nullable ApiKeySecrets apiKeySecrets, + VoyageAIRateLimitServiceSettings rateLimitServiceSettings + ) { + this(configurations, secrets, apiKeySecrets, rateLimitServiceSettings, null); + } + + public VoyageAIModel( + ModelConfigurations configurations, + ModelSecrets secrets, + @Nullable ApiKeySecrets apiKeySecrets, + VoyageAIRateLimitServiceSettings rateLimitServiceSettings, + String url + ) { + super(configurations, secrets); + + this.rateLimitServiceSettings = Objects.requireNonNull(rateLimitServiceSettings); + this.apiKey = ServiceUtils.apiKey(apiKeySecrets); + this.uri = url == null ? null : URI.create(url); + } + + protected VoyageAIModel(VoyageAIModel model, TaskSettings taskSettings) { + super(model, taskSettings); + + this.rateLimitServiceSettings = model.rateLimitServiceSettings(); + this.apiKey = model.apiKey(); + this.uri = model.uri; + } + + protected VoyageAIModel(VoyageAIModel model, ServiceSettings serviceSettings) { + super(model, serviceSettings); + + this.rateLimitServiceSettings = model.rateLimitServiceSettings(); + this.apiKey = model.apiKey(); + this.uri = model.uri; + } + + public SecureString apiKey() { + return apiKey; + } + + public VoyageAIRateLimitServiceSettings rateLimitServiceSettings() { + return rateLimitServiceSettings; + } + + public abstract ExecutableAction accept(VoyageAIActionVisitor creator, Map taskSettings, InputType inputType); + + public URI uri() { + return uri; + } + + public URI buildUri() throws URISyntaxException { + if (uri == null) { + return buildRequestUri(); + } + return uri; + } + + protected abstract URI buildRequestUri() throws URISyntaxException; +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIRateLimitServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIRateLimitServiceSettings.java new file mode 100644 index 0000000000000..a4b325fe2db41 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIRateLimitServiceSettings.java @@ -0,0 +1,15 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai; + +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +public interface VoyageAIRateLimitServiceSettings { + RateLimitSettings rateLimitSettings(); + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java new file mode 100644 index 0000000000000..f92779de9b7f5 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java @@ -0,0 +1,397 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.util.LazyInitializable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SettingsConfiguration; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; +import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; +import org.elasticsearch.xpack.inference.external.action.voyageai.VoyageAIActionCreator; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel; + +import java.util.EnumSet; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; + +public class VoyageAIService extends SenderService { + public static final String NAME = "voyageai"; + + private static final String SERVICE_NAME = "Voyage AI"; + private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK); + + private static final Integer DEFAULT_BATCH_SIZE = 7; + private static final Map MODEL_BATCH_SIZES = Map.of( + "voyage-multimodal-3", + 7, + "voyage-3-large", + 7, + "voyage-code-3", + 7, + "voyage-3", + 10, + "voyage-3-lite", + 30, + "voyage-finance-2", + 7, + "voyage-law-2", + 7, + "voyage-code-2", + 7, + "voyage-2", + 72, + "voyage-02", + 72 + ); + + public VoyageAIService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { + super(factory, serviceComponents); + } + + @Override + public String name() { + return NAME; + } + + @Override + public void parseRequestConfig( + String inferenceEntityId, + TaskType taskType, + Map config, + ActionListener parsedModelListener + ) { + try { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap( + removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS) + ); + } + VoyageAIModel model = createModel( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + chunkingSettings, + serviceSettingsMap, + TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), + ConfigurationParseContext.REQUEST + ); + + throwIfNotEmptyMap(config, NAME); + throwIfNotEmptyMap(serviceSettingsMap, NAME); + throwIfNotEmptyMap(taskSettingsMap, NAME); + + parsedModelListener.onResponse(model); + } catch (Exception e) { + parsedModelListener.onFailure(e); + } + } + + private static VoyageAIModel createModelFromPersistent( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + ChunkingSettings chunkingSettings, + @Nullable Map secretSettings, + String failureMessage + ) { + return createModel( + inferenceEntityId, + taskType, + serviceSettings, + taskSettings, + chunkingSettings, + secretSettings, + failureMessage, + ConfigurationParseContext.PERSISTENT + ); + } + + private static VoyageAIModel createModel( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + ChunkingSettings chunkingSettings, + @Nullable Map secretSettings, + String failureMessage, + ConfigurationParseContext context + ) { + return switch (taskType) { + case TEXT_EMBEDDING -> new VoyageAIEmbeddingsModel( + inferenceEntityId, + NAME, + serviceSettings, + taskSettings, + chunkingSettings, + secretSettings, + context + ); + case RERANK -> new VoyageAIRerankModel(inferenceEntityId, NAME, serviceSettings, taskSettings, secretSettings, context); + default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + }; + } + + @Override + public VoyageAIModel parsePersistedConfigWithSecrets( + String inferenceEntityId, + TaskType taskType, + Map config, + Map secrets + ) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + Map secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); + + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + + return createModelFromPersistent( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + chunkingSettings, + secretSettingsMap, + parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + ); + } + + @Override + public VoyageAIModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + + return createModelFromPersistent( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + chunkingSettings, + null, + parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + ); + } + + @Override + public InferenceServiceConfiguration getConfiguration() { + return Configuration.get(); + } + + @Override + public EnumSet supportedTaskTypes() { + return supportedTaskTypes; + } + + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + + @Override + public void doInfer( + Model model, + InferenceInputs inputs, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener listener + ) { + if (model instanceof VoyageAIModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + VoyageAIModel voyageaiModel = (VoyageAIModel) model; + var actionCreator = new VoyageAIActionCreator(getSender(), getServiceComponents()); + + var action = voyageaiModel.accept(actionCreator, taskSettings, inputType); + action.execute(inputs, timeout, listener); + } + + @Override + protected void doChunkedInfer( + Model model, + DocumentsOnlyInput inputs, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener> listener + ) { + if (model instanceof VoyageAIModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + VoyageAIModel voyageaiModel = (VoyageAIModel) model; + var actionCreator = new VoyageAIActionCreator(getSender(), getServiceComponents()); + + List batchedRequests = new EmbeddingRequestChunker( + inputs.getInputs(), + getBatchSize(voyageaiModel), + EmbeddingRequestChunker.EmbeddingType.fromDenseVectorElementType(model.getServiceSettings().elementType()), + voyageaiModel.getConfigurations().getChunkingSettings() + ).batchRequestsWithListeners(listener); + + for (var request : batchedRequests) { + var action = voyageaiModel.accept(actionCreator, taskSettings, inputType); + action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener()); + } + } + + private static int getBatchSize(VoyageAIModel model) { + return MODEL_BATCH_SIZES.getOrDefault(model.getServiceSettings().modelId(), DEFAULT_BATCH_SIZE); + } + + /** + * For text embedding models get the embedding size and + * update the service settings. + * + * @param model The new model + * @param listener The listener + */ + @Override + public void checkModelConfig(Model model, ActionListener listener) { + ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); + } + + @Override + public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { + if (model instanceof VoyageAIEmbeddingsModel embeddingsModel) { + var serviceSettings = embeddingsModel.getServiceSettings(); + var similarityFromModel = serviceSettings.similarity(); + var similarityToUse = similarityFromModel == null ? defaultSimilarity() : similarityFromModel; + var maxInputTokens = serviceSettings.maxInputTokens(); + var dimensionSetByUser = serviceSettings.dimensionsSetByUser(); + + var updatedServiceSettings = new VoyageAIEmbeddingsServiceSettings( + new VoyageAIServiceSettings( + serviceSettings.getCommonSettings().modelId(), + serviceSettings.getCommonSettings().rateLimitSettings() + ), + serviceSettings.getEmbeddingType(), + similarityToUse, + embeddingSize, + maxInputTokens, + dimensionSetByUser + ); + + return new VoyageAIEmbeddingsModel(embeddingsModel, updatedServiceSettings); + } else { + throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass()); + } + } + + /** + * Return the default similarity measure for the embedding type. + * VoyageAI embeddings are normalized to unit vectors therefore Dot + * Product similarity can be used and is the default for all VoyageAI + * models. + * + * @return The default similarity. + */ + static SimilarityMeasure defaultSimilarity() { + return SimilarityMeasure.DOT_PRODUCT; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.VOYAGE_AI_INTEGRATION_ADDED; + } + + public static class Configuration { + public static InferenceServiceConfiguration get() { + return configuration.getOrCompute(); + } + + private static final LazyInitializable configuration = new LazyInitializable<>( + () -> { + var configurationMap = new HashMap(); + + configurationMap.put( + MODEL_ID, + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription( + "The name of the model to use for the inference task." + ) + .setLabel("Model ID") + .setRequired(true) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ); + + configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(supportedTaskTypes)); + configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(supportedTaskTypes)); + + return new InferenceServiceConfiguration.Builder().setService(NAME) + .setName(SERVICE_NAME) + .setTaskTypes(supportedTaskTypes) + .setConfigurations(configurationMap) + .build(); + } + ); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceFields.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceFields.java new file mode 100644 index 0000000000000..b36f212b61e5d --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceFields.java @@ -0,0 +1,12 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai; + +public class VoyageAIServiceFields { + public static final String TRUNCATION = "truncation"; +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceSettings.java new file mode 100644 index 0000000000000..75497d1a4b4f0 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceSettings.java @@ -0,0 +1,132 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; + +public class VoyageAIServiceSettings extends FilteredXContentObject implements ServiceSettings, VoyageAIRateLimitServiceSettings { + + public static final String NAME = "voyageai_service_settings"; + public static final String MODEL_ID = "model_id"; + private static final Logger logger = LogManager.getLogger(VoyageAIServiceSettings.class); + // See https://docs.voyageai.com/docs/rate-limits + public static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(2_000); + + public static VoyageAIServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + VoyageAIService.NAME, + context + ); + + String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new VoyageAIServiceSettings(modelId, rateLimitSettings); + } + + private final String modelId; + private final RateLimitSettings rateLimitSettings; + + public VoyageAIServiceSettings(String modelId, @Nullable RateLimitSettings rateLimitSettings) { + this.modelId = Objects.requireNonNull(modelId); + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + public VoyageAIServiceSettings(StreamInput in) throws IOException { + modelId = in.readString(); + rateLimitSettings = new RateLimitSettings(in); + } + + @Override + public RateLimitSettings rateLimitSettings() { + return rateLimitSettings; + } + + @Override + public String modelId() { + return modelId; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + toXContentFragment(builder, params); + + builder.endObject(); + return builder; + } + + public XContentBuilder toXContentFragment(XContentBuilder builder, Params params) throws IOException { + return toXContentFragmentOfExposedFields(builder, params); + } + + @Override + public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + builder.field(MODEL_ID, modelId); + rateLimitSettings.toXContent(builder, params); + + return builder; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.VOYAGE_AI_INTEGRATION_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + rateLimitSettings.writeTo(out); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + VoyageAIServiceSettings that = (VoyageAIServiceSettings) o; + return Objects.equals(modelId, that.modelId) && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, rateLimitSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingType.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingType.java new file mode 100644 index 0000000000000..db13e46b14641 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingType.java @@ -0,0 +1,114 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai.embeddings; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; + +import java.util.Arrays; +import java.util.EnumSet; +import java.util.Locale; +import java.util.Map; + +/** + * Defines the type of embedding that the VoyageAI api should return for a request. + * + *

+ * See api docs for details. + *

+ */ +public enum VoyageAIEmbeddingType { + /** + * Use this when you want to get back the default float embeddings. Valid for all models. + */ + FLOAT(DenseVectorFieldMapper.ElementType.FLOAT, RequestConstants.FLOAT), + /** + * Use this when you want to get back signed int8 embeddings. Valid for only v3 models. + */ + INT8(DenseVectorFieldMapper.ElementType.BYTE, RequestConstants.INT8), + /** + * This is a synonym for INT8 + */ + BYTE(DenseVectorFieldMapper.ElementType.BYTE, RequestConstants.INT8), + /** + * Use this when you want to get back binary embeddings. Valid only for v3 models. + */ + BIT(DenseVectorFieldMapper.ElementType.BIT, RequestConstants.BINARY), + /** + * This is a synonym for BIT + */ + BINARY(DenseVectorFieldMapper.ElementType.BIT, RequestConstants.BINARY); + + private static final class RequestConstants { + private static final String FLOAT = "float"; + private static final String INT8 = "int8"; + private static final String BINARY = "binary"; + } + + private static final Map ELEMENT_TYPE_TO_VOYAGE_EMBEDDING = Map.of( + DenseVectorFieldMapper.ElementType.FLOAT, + FLOAT, + DenseVectorFieldMapper.ElementType.BYTE, + BYTE, + DenseVectorFieldMapper.ElementType.BIT, + BIT + ); + static final EnumSet SUPPORTED_ELEMENT_TYPES = EnumSet.copyOf( + ELEMENT_TYPE_TO_VOYAGE_EMBEDDING.keySet() + ); + + private final DenseVectorFieldMapper.ElementType elementType; + private final String requestString; + + VoyageAIEmbeddingType(DenseVectorFieldMapper.ElementType elementType, String requestString) { + this.elementType = elementType; + this.requestString = requestString; + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } + + public String toRequestString() { + return requestString; + } + + public static String toLowerCase(VoyageAIEmbeddingType type) { + return type.toString().toLowerCase(Locale.ROOT); + } + + public static VoyageAIEmbeddingType fromString(String name) { + return valueOf(name.trim().toUpperCase(Locale.ROOT)); + } + + public static VoyageAIEmbeddingType fromElementType(DenseVectorFieldMapper.ElementType elementType) { + var embedding = ELEMENT_TYPE_TO_VOYAGE_EMBEDDING.get(elementType); + + if (embedding == null) { + var validElementTypes = SUPPORTED_ELEMENT_TYPES.stream() + .map(value -> value.toString().toLowerCase(Locale.ROOT)) + .toArray(String[]::new); + Arrays.sort(validElementTypes); + + throw new IllegalArgumentException( + Strings.format( + "Element type [%s] does not map to a VoyageAI embedding value, must be one of [%s]", + elementType, + String.join(", ", validElementTypes) + ) + ); + } + + return embedding; + } + + public DenseVectorFieldMapper.ElementType toElementType() { + return elementType; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsModel.java new file mode 100644 index 0000000000000..41194f6862a44 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsModel.java @@ -0,0 +1,127 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai.embeddings; + +import org.apache.http.client.utils.URIBuilder; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.voyageai.VoyageAIActionVisitor; +import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIUtils; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIModel; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIUtils.HOST; + +public class VoyageAIEmbeddingsModel extends VoyageAIModel { + public static VoyageAIEmbeddingsModel of(VoyageAIEmbeddingsModel model, Map taskSettings, InputType inputType) { + var requestTaskSettings = VoyageAIEmbeddingsTaskSettings.fromMap(taskSettings); + return new VoyageAIEmbeddingsModel( + model, + VoyageAIEmbeddingsTaskSettings.of(model.getTaskSettings(), requestTaskSettings, inputType) + ); + } + + public VoyageAIEmbeddingsModel( + String inferenceId, + String service, + Map serviceSettings, + Map taskSettings, + ChunkingSettings chunkingSettings, + @Nullable Map secrets, + ConfigurationParseContext context + ) { + this( + inferenceId, + service, + VoyageAIEmbeddingsServiceSettings.fromMap(serviceSettings, context), + VoyageAIEmbeddingsTaskSettings.fromMap(taskSettings), + chunkingSettings, + DefaultSecretSettings.fromMap(secrets) + ); + } + + // should only be used for testing + VoyageAIEmbeddingsModel( + String modelId, + String service, + VoyageAIEmbeddingsServiceSettings serviceSettings, + VoyageAIEmbeddingsTaskSettings taskSettings, + ChunkingSettings chunkingSettings, + @Nullable DefaultSecretSettings secretSettings + ) { + super( + new ModelConfigurations(modelId, TaskType.TEXT_EMBEDDING, service, serviceSettings, taskSettings, chunkingSettings), + new ModelSecrets(secretSettings), + secretSettings, + serviceSettings.getCommonSettings() + ); + } + + VoyageAIEmbeddingsModel( + String modelId, + String service, + String url, + VoyageAIEmbeddingsServiceSettings serviceSettings, + VoyageAIEmbeddingsTaskSettings taskSettings, + ChunkingSettings chunkingSettings, + @Nullable DefaultSecretSettings secretSettings + ) { + super( + new ModelConfigurations(modelId, TaskType.TEXT_EMBEDDING, service, serviceSettings, taskSettings, chunkingSettings), + new ModelSecrets(secretSettings), + secretSettings, + serviceSettings.getCommonSettings(), + url + ); + } + + private VoyageAIEmbeddingsModel(VoyageAIEmbeddingsModel model, VoyageAIEmbeddingsTaskSettings taskSettings) { + super(model, taskSettings); + } + + public VoyageAIEmbeddingsModel(VoyageAIEmbeddingsModel model, VoyageAIEmbeddingsServiceSettings serviceSettings) { + super(model, serviceSettings); + } + + @Override + public VoyageAIEmbeddingsServiceSettings getServiceSettings() { + return (VoyageAIEmbeddingsServiceSettings) super.getServiceSettings(); + } + + @Override + public VoyageAIEmbeddingsTaskSettings getTaskSettings() { + return (VoyageAIEmbeddingsTaskSettings) super.getTaskSettings(); + } + + @Override + public DefaultSecretSettings getSecretSettings() { + return (DefaultSecretSettings) super.getSecretSettings(); + } + + @Override + public ExecutableAction accept(VoyageAIActionVisitor visitor, Map taskSettings, InputType inputType) { + return visitor.create(this, taskSettings, inputType); + } + + protected URI buildRequestUri() throws URISyntaxException { + return new URIBuilder().setScheme("https") + .setHost(HOST) + .setPathSegments(VoyageAIUtils.VERSION_1, VoyageAIUtils.EMBEDDINGS_PATH) + .build(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettings.java new file mode 100644 index 0000000000000..cc4db278d0e2b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettings.java @@ -0,0 +1,259 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai.embeddings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings; + +import java.io.IOException; +import java.util.EnumSet; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType; + +public class VoyageAIEmbeddingsServiceSettings extends FilteredXContentObject implements ServiceSettings { + public static final String NAME = "voyageai_embeddings_service_settings"; + static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user"; + public static final VoyageAIEmbeddingsServiceSettings EMPTY_SETTINGS = new VoyageAIEmbeddingsServiceSettings( + null, + null, + null, + null, + null, + false + ); + + public static final String EMBEDDING_TYPE = "embedding_type"; + + public static VoyageAIEmbeddingsServiceSettings fromMap(Map map, ConfigurationParseContext context) { + return switch (context) { + case REQUEST -> fromRequestMap(map, context); + case PERSISTENT -> fromPersistentMap(map, context); + }; + } + + private static VoyageAIEmbeddingsServiceSettings fromRequestMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + var commonServiceSettings = VoyageAIServiceSettings.fromMap(map, context); + + VoyageAIEmbeddingType embeddingTypes = parseEmbeddingType(map, context, validationException); + + SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); + Integer dims = removeAsType(map, DIMENSIONS, Integer.class); + Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new VoyageAIEmbeddingsServiceSettings(commonServiceSettings, embeddingTypes, similarity, dims, maxInputTokens, dims != null); + } + + private static VoyageAIEmbeddingsServiceSettings fromPersistentMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + var commonServiceSettings = VoyageAIServiceSettings.fromMap(map, context); + + VoyageAIEmbeddingType embeddingTypes = parseEmbeddingType(map, context, validationException); + + SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); + Integer dims = removeAsType(map, DIMENSIONS, Integer.class); + Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); + + Boolean dimensionsSetByUser = removeAsType(map, DIMENSIONS_SET_BY_USER, Boolean.class); + if (dimensionsSetByUser == null) { + dimensionsSetByUser = Boolean.FALSE; + } + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new VoyageAIEmbeddingsServiceSettings( + commonServiceSettings, + embeddingTypes, + similarity, + dims, + maxInputTokens, + dimensionsSetByUser + ); + } + + static VoyageAIEmbeddingType parseEmbeddingType( + Map map, + ConfigurationParseContext context, + ValidationException validationException + ) { + return switch (context) { + case REQUEST, PERSISTENT -> Objects.requireNonNullElse( + extractOptionalEnum( + map, + EMBEDDING_TYPE, + ModelConfigurations.SERVICE_SETTINGS, + VoyageAIEmbeddingType::fromString, + EnumSet.allOf(VoyageAIEmbeddingType.class), + validationException + ), + VoyageAIEmbeddingType.FLOAT + ); + + }; + } + + private final VoyageAIServiceSettings commonSettings; + private final VoyageAIEmbeddingType embeddingType; + private final SimilarityMeasure similarity; + private final Integer dimensions; + private final Integer maxInputTokens; + private final boolean dimensionsSetByUser; + + public VoyageAIEmbeddingsServiceSettings( + VoyageAIServiceSettings commonSettings, + @Nullable VoyageAIEmbeddingType embeddingType, + @Nullable SimilarityMeasure similarity, + @Nullable Integer dimensions, + @Nullable Integer maxInputTokens, + boolean dimensionsSetByUser + ) { + this.commonSettings = commonSettings; + this.similarity = similarity; + this.dimensions = dimensions; + this.maxInputTokens = maxInputTokens; + this.embeddingType = embeddingType; + this.dimensionsSetByUser = dimensionsSetByUser; + } + + public VoyageAIEmbeddingsServiceSettings(StreamInput in) throws IOException { + this.commonSettings = new VoyageAIServiceSettings(in); + this.similarity = in.readOptionalEnum(SimilarityMeasure.class); + this.dimensions = in.readOptionalVInt(); + this.maxInputTokens = in.readOptionalVInt(); + this.embeddingType = Objects.requireNonNullElse(in.readOptionalEnum(VoyageAIEmbeddingType.class), VoyageAIEmbeddingType.FLOAT); + this.dimensionsSetByUser = in.readBoolean(); + } + + public VoyageAIServiceSettings getCommonSettings() { + return commonSettings; + } + + @Override + public SimilarityMeasure similarity() { + return similarity; + } + + @Override + public Integer dimensions() { + return dimensions; + } + + public Integer maxInputTokens() { + return maxInputTokens; + } + + @Override + public String modelId() { + return commonSettings.modelId(); + } + + public VoyageAIEmbeddingType getEmbeddingType() { + return embeddingType; + } + + @Override + public DenseVectorFieldMapper.ElementType elementType() { + return embeddingType == null ? DenseVectorFieldMapper.ElementType.FLOAT : embeddingType.toElementType(); + } + + @Override + public Boolean dimensionsSetByUser() { + return this.dimensionsSetByUser; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + builder = commonSettings.toXContentFragment(builder, params); + if (similarity != null) { + builder.field(SIMILARITY, similarity); + } + if (dimensions != null) { + builder.field(DIMENSIONS, dimensions); + } + if (maxInputTokens != null) { + builder.field(MAX_INPUT_TOKENS, maxInputTokens); + } + if (embeddingType != null) { + builder.field(EMBEDDING_TYPE, embeddingType); + } + builder.endObject(); + return builder; + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + commonSettings.toXContentFragmentOfExposedFields(builder, params); + + return builder; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.VOYAGE_AI_INTEGRATION_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + commonSettings.writeTo(out); + out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion())); + out.writeOptionalVInt(dimensions); + out.writeOptionalVInt(maxInputTokens); + out.writeOptionalEnum(embeddingType); + out.writeBoolean(dimensionsSetByUser); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + VoyageAIEmbeddingsServiceSettings that = (VoyageAIEmbeddingsServiceSettings) o; + return Objects.equals(commonSettings, that.commonSettings) + && Objects.equals(similarity, that.similarity) + && Objects.equals(dimensions, that.dimensions) + && Objects.equals(maxInputTokens, that.maxInputTokens) + && Objects.equals(embeddingType, that.embeddingType) + && Objects.equals(dimensionsSetByUser, that.dimensionsSetByUser); + } + + @Override + public int hashCode() { + return Objects.hash(commonSettings, similarity, dimensions, maxInputTokens, embeddingType, dimensionsSetByUser); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsTaskSettings.java new file mode 100644 index 0000000000000..5d8d282588349 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsTaskSettings.java @@ -0,0 +1,202 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai.embeddings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum; +import static org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceFields.TRUNCATION; + +/** + * Defines the task settings for the voyageai text embeddings service. + * + *

+ * See api docs for details. + *

+ */ +public class VoyageAIEmbeddingsTaskSettings implements TaskSettings { + + public static final String NAME = "voyageai_embeddings_task_settings"; + public static final VoyageAIEmbeddingsTaskSettings EMPTY_SETTINGS = new VoyageAIEmbeddingsTaskSettings(null, null); + static final String INPUT_TYPE = "input_type"; + static final EnumSet VALID_REQUEST_VALUES = EnumSet.of(InputType.INGEST, InputType.SEARCH); + + public static VoyageAIEmbeddingsTaskSettings fromMap(Map map) { + if (map == null || map.isEmpty()) { + return EMPTY_SETTINGS; + } + + ValidationException validationException = new ValidationException(); + + InputType inputType = extractOptionalEnum( + map, + INPUT_TYPE, + ModelConfigurations.TASK_SETTINGS, + InputType::fromString, + VALID_REQUEST_VALUES, + validationException + ); + Boolean truncation = extractOptionalBoolean(map, TRUNCATION, validationException); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new VoyageAIEmbeddingsTaskSettings(inputType, truncation); + } + + /** + * Creates a new {@link VoyageAIEmbeddingsTaskSettings} by preferring non-null fields from the provided parameters. + * For the input type, preference is given to requestInputType if it is not null and not UNSPECIFIED. + * Then preference is given to the requestTaskSettings and finally to originalSettings even if the value is null. + * Similarly, for the truncation field preference is given to requestTaskSettings if it is not null and then to + * originalSettings. + * @param originalSettings the settings stored as part of the inference entity configuration + * @param requestTaskSettings the settings passed in within the task_settings field of the request + * @param requestInputType the input type passed in the request parameters + * @return a constructed {@link VoyageAIEmbeddingsTaskSettings} + */ + public static VoyageAIEmbeddingsTaskSettings of( + VoyageAIEmbeddingsTaskSettings originalSettings, + VoyageAIEmbeddingsTaskSettings requestTaskSettings, + InputType requestInputType + ) { + var inputTypeToUse = getValidInputType(originalSettings, requestTaskSettings, requestInputType); + var truncationToUse = getValidTruncation(originalSettings, requestTaskSettings); + + return new VoyageAIEmbeddingsTaskSettings(inputTypeToUse, truncationToUse); + } + + private static InputType getValidInputType( + VoyageAIEmbeddingsTaskSettings originalSettings, + VoyageAIEmbeddingsTaskSettings requestTaskSettings, + InputType requestInputType + ) { + InputType inputTypeToUse = originalSettings.inputType; + + if (VALID_REQUEST_VALUES.contains(requestInputType)) { + inputTypeToUse = requestInputType; + } else if (requestTaskSettings.inputType != null) { + inputTypeToUse = requestTaskSettings.inputType; + } + + return inputTypeToUse; + } + + private static Boolean getValidTruncation( + VoyageAIEmbeddingsTaskSettings originalSettings, + VoyageAIEmbeddingsTaskSettings requestTaskSettings + ) { + return requestTaskSettings.getTruncation() == null ? originalSettings.truncation : requestTaskSettings.getTruncation(); + } + + private final InputType inputType; + private final Boolean truncation; + + public VoyageAIEmbeddingsTaskSettings(StreamInput in) throws IOException { + this(in.readOptionalEnum(InputType.class), in.readOptionalBoolean()); + } + + public VoyageAIEmbeddingsTaskSettings(@Nullable InputType inputType, @Nullable Boolean truncation) { + validateInputType(inputType); + this.inputType = inputType; + this.truncation = truncation; + } + + private static void validateInputType(InputType inputType) { + if (inputType == null) { + return; + } + + assert VALID_REQUEST_VALUES.contains(inputType) : invalidInputTypeMessage(inputType); + } + + @Override + public boolean isEmpty() { + return inputType == null && truncation == null; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (inputType != null) { + builder.field(INPUT_TYPE, inputType); + } + + if (truncation != null) { + builder.field(TRUNCATION, truncation); + } + + builder.endObject(); + return builder; + } + + public InputType getInputType() { + return inputType; + } + + public Boolean getTruncation() { + return truncation; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.VOYAGE_AI_INTEGRATION_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalEnum(inputType); + out.writeOptionalBoolean(truncation); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + VoyageAIEmbeddingsTaskSettings that = (VoyageAIEmbeddingsTaskSettings) o; + return Objects.equals(inputType, that.inputType) && Objects.equals(truncation, that.truncation); + } + + @Override + public int hashCode() { + return Objects.hash(inputType, truncation); + } + + public static String invalidInputTypeMessage(InputType inputType) { + return Strings.format("received invalid input type value [%s]", inputType.toString()); + } + + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + VoyageAIEmbeddingsTaskSettings updatedSettings = VoyageAIEmbeddingsTaskSettings.fromMap(new HashMap<>(newSettings)); + return of(this, updatedSettings, updatedSettings.inputType != null ? updatedSettings.inputType : this.inputType); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModel.java new file mode 100644 index 0000000000000..57c478962b5f2 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModel.java @@ -0,0 +1,122 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai.rerank; + +import org.apache.http.client.utils.URIBuilder; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.voyageai.VoyageAIActionVisitor; +import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIUtils; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIModel; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIUtils.HOST; + +public class VoyageAIRerankModel extends VoyageAIModel { + public static VoyageAIRerankModel of(VoyageAIRerankModel model, Map taskSettings) { + var requestTaskSettings = VoyageAIRerankTaskSettings.fromMap(taskSettings); + return new VoyageAIRerankModel(model, VoyageAIRerankTaskSettings.of(model.getTaskSettings(), requestTaskSettings)); + } + + public VoyageAIRerankModel( + String inferenceId, + String service, + Map serviceSettings, + Map taskSettings, + @Nullable Map secrets, + ConfigurationParseContext context + ) { + this( + inferenceId, + service, + VoyageAIRerankServiceSettings.fromMap(serviceSettings, context), + VoyageAIRerankTaskSettings.fromMap(taskSettings), + DefaultSecretSettings.fromMap(secrets) + ); + } + + // should only be used for testing + VoyageAIRerankModel( + String modelId, + String service, + VoyageAIRerankServiceSettings serviceSettings, + VoyageAIRerankTaskSettings taskSettings, + @Nullable DefaultSecretSettings secretSettings + ) { + this(modelId, service, null, serviceSettings, taskSettings, secretSettings); + } + + VoyageAIRerankModel( + String modelId, + String service, + String url, + VoyageAIRerankServiceSettings serviceSettings, + VoyageAIRerankTaskSettings taskSettings, + @Nullable DefaultSecretSettings secretSettings + ) { + super( + new ModelConfigurations(modelId, TaskType.RERANK, service, serviceSettings, taskSettings), + new ModelSecrets(secretSettings), + secretSettings, + serviceSettings.getCommonSettings(), + url + ); + } + + private VoyageAIRerankModel(VoyageAIRerankModel model, VoyageAIRerankTaskSettings taskSettings) { + super(model, taskSettings); + } + + public VoyageAIRerankModel(VoyageAIRerankModel model, VoyageAIRerankServiceSettings serviceSettings) { + super(model, serviceSettings); + } + + @Override + public VoyageAIRerankServiceSettings getServiceSettings() { + return (VoyageAIRerankServiceSettings) super.getServiceSettings(); + } + + @Override + public VoyageAIRerankTaskSettings getTaskSettings() { + return (VoyageAIRerankTaskSettings) super.getTaskSettings(); + } + + @Override + public DefaultSecretSettings getSecretSettings() { + return (DefaultSecretSettings) super.getSecretSettings(); + } + + /** + * Accepts a visitor to create an executable action. The returned action will not return documents in the response. + * @param visitor _ + * @param taskSettings _ + * @param inputType ignored for rerank task + * @return the rerank action + */ + @Override + public ExecutableAction accept(VoyageAIActionVisitor visitor, Map taskSettings, InputType inputType) { + return visitor.create(this, taskSettings); + } + + @Override + protected URI buildRequestUri() throws URISyntaxException { + return new URIBuilder().setScheme("https") + .setHost(HOST) + .setPathSegments(VoyageAIUtils.VERSION_1, VoyageAIUtils.RERANK_PATH) + .build(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankServiceSettings.java new file mode 100644 index 0000000000000..1d3607922c5c2 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankServiceSettings.java @@ -0,0 +1,113 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai.rerank; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +public class VoyageAIRerankServiceSettings extends FilteredXContentObject implements ServiceSettings, VoyageAIRateLimitServiceSettings { + public static final String NAME = "voyageai_rerank_service_settings"; + + private static final Logger logger = LogManager.getLogger(VoyageAIRerankServiceSettings.class); + + public static VoyageAIRerankServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + var commonServiceSettings = VoyageAIServiceSettings.fromMap(map, context); + + return new VoyageAIRerankServiceSettings(commonServiceSettings); + } + + private final VoyageAIServiceSettings commonSettings; + + public VoyageAIRerankServiceSettings(VoyageAIServiceSettings commonSettings) { + this.commonSettings = commonSettings; + } + + public VoyageAIRerankServiceSettings(StreamInput in) throws IOException { + this.commonSettings = new VoyageAIServiceSettings(in); + } + + public VoyageAIServiceSettings getCommonSettings() { + return commonSettings; + } + + @Override + public String modelId() { + return commonSettings.modelId(); + } + + @Override + public RateLimitSettings rateLimitSettings() { + return commonSettings.rateLimitSettings(); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + builder = commonSettings.toXContentFragment(builder, params); + + builder.endObject(); + return builder; + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + commonSettings.toXContentFragmentOfExposedFields(builder, params); + return builder; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.VOYAGE_AI_INTEGRATION_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + commonSettings.writeTo(out); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + VoyageAIRerankServiceSettings that = (VoyageAIRerankServiceSettings) o; + return Objects.equals(commonSettings, that.commonSettings); + } + + @Override + public int hashCode() { + return Objects.hash(commonSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettings.java new file mode 100644 index 0000000000000..a5004fde1e17e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettings.java @@ -0,0 +1,184 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai.rerank; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; +import static org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceFields.TRUNCATION; + +/** + * Defines the task settings for the VoyageAI rerank service. + * + */ +public class VoyageAIRerankTaskSettings implements TaskSettings { + + public static final String NAME = "voyageai_rerank_task_settings"; + public static final String RETURN_DOCUMENTS = "return_documents"; + public static final String TOP_K_DOCS_ONLY = "top_k"; + + public static final VoyageAIRerankTaskSettings EMPTY_SETTINGS = new VoyageAIRerankTaskSettings(null, null, null); + + public static VoyageAIRerankTaskSettings fromMap(Map map) { + ValidationException validationException = new ValidationException(); + + if (map == null || map.isEmpty()) { + return EMPTY_SETTINGS; + } + + Boolean returnDocuments = extractOptionalBoolean(map, RETURN_DOCUMENTS, validationException); + Integer topKDocumentsOnly = extractOptionalPositiveInteger( + map, + TOP_K_DOCS_ONLY, + ModelConfigurations.TASK_SETTINGS, + validationException + ); + + Boolean truncation = extractOptionalBoolean(map, TRUNCATION, validationException); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return of(topKDocumentsOnly, returnDocuments, truncation); + } + + /** + * Creates a new {@link VoyageAIRerankTaskSettings} by preferring non-null fields from the request settings over the original settings. + * + * @param originalSettings the settings stored as part of the inference entity configuration + * @param requestTaskSettings the settings passed in within the task_settings field of the request + * @return a constructed {@link VoyageAIRerankTaskSettings} + */ + public static VoyageAIRerankTaskSettings of( + VoyageAIRerankTaskSettings originalSettings, + VoyageAIRerankTaskSettings requestTaskSettings + ) { + return new VoyageAIRerankTaskSettings( + requestTaskSettings.getTopKDocumentsOnly() != null + ? requestTaskSettings.getTopKDocumentsOnly() + : originalSettings.getTopKDocumentsOnly(), + requestTaskSettings.getReturnDocuments() != null + ? requestTaskSettings.getReturnDocuments() + : originalSettings.getReturnDocuments(), + requestTaskSettings.getTruncation() != null ? requestTaskSettings.getTruncation() : originalSettings.getTruncation() + + ); + } + + public static VoyageAIRerankTaskSettings of(Integer topKDocumentsOnly, Boolean returnDocuments, Boolean truncation) { + return new VoyageAIRerankTaskSettings(topKDocumentsOnly, returnDocuments, truncation); + } + + private final Integer topKDocumentsOnly; + private final Boolean returnDocuments; + private final Boolean truncation; + + public VoyageAIRerankTaskSettings(StreamInput in) throws IOException { + this(in.readOptionalInt(), in.readOptionalBoolean(), in.readOptionalBoolean()); + } + + public VoyageAIRerankTaskSettings( + @Nullable Integer topKDocumentsOnly, + @Nullable Boolean doReturnDocuments, + @Nullable Boolean truncation + ) { + this.topKDocumentsOnly = topKDocumentsOnly; + this.returnDocuments = doReturnDocuments; + this.truncation = truncation; + } + + @Override + public boolean isEmpty() { + return topKDocumentsOnly == null && returnDocuments == null && truncation == null; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (topKDocumentsOnly != null) { + builder.field(TOP_K_DOCS_ONLY, topKDocumentsOnly); + } + if (returnDocuments != null) { + builder.field(RETURN_DOCUMENTS, returnDocuments); + } + if (truncation != null) { + builder.field(TRUNCATION, truncation); + } + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.VOYAGE_AI_INTEGRATION_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalInt(topKDocumentsOnly); + out.writeOptionalBoolean(returnDocuments); + out.writeOptionalBoolean(truncation); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + VoyageAIRerankTaskSettings that = (VoyageAIRerankTaskSettings) o; + return Objects.equals(topKDocumentsOnly, that.topKDocumentsOnly) + && Objects.equals(returnDocuments, that.returnDocuments) + && Objects.equals(truncation, that.truncation); + } + + @Override + public int hashCode() { + return Objects.hash(truncation, returnDocuments, topKDocumentsOnly); + } + + public Integer getTopKDocumentsOnly() { + return topKDocumentsOnly; + } + + public Boolean getDoesReturnDocuments() { + return returnDocuments; + } + + public Boolean getReturnDocuments() { + return returnDocuments; + } + + public Boolean getTruncation() { + return truncation; + } + + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + VoyageAIRerankTaskSettings updatedSettings = VoyageAIRerankTaskSettings.fromMap(new HashMap<>(newSettings)); + return VoyageAIRerankTaskSettings.of(this, updatedSettings); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreatorTests.java new file mode 100644 index 0000000000000..50c64468d732a --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreatorTests.java @@ -0,0 +1,145 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.voyageai; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingType; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettingsTests; +import org.hamcrest.MatcherAssert; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class VoyageAIActionCreatorTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testCreate_VoyageAIEmbeddingsModel() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "object": "list", + "data": [{ + "object": "embedding", + "embedding": [ + 0.123, + -0.123 + ], + "index": 0 + }], + "model": "voyage-3-large", + "usage": { + "total_tokens": 123 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = VoyageAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, true), + 1024, + 1024, + "model", + VoyageAIEmbeddingType.FLOAT + ); + var actionCreator = new VoyageAIActionCreator(sender, createWithEmptySettings(threadPool)); + var overriddenTaskSettings = VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH); + var action = actionCreator.create(model, overriddenTaskSettings, InputType.UNSPECIFIED); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + MatcherAssert.assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })))); + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().getFirst().getUri().getQuery()); + MatcherAssert.assertThat( + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + MatcherAssert.assertThat( + requestMap, + is( + Map.of( + "output_dtype", + "float", + "truncation", + true, + "input_type", + "query", + "output_dimension", + 1024, + "input", + List.of("abc"), + "model", + "model" + ) + ) + ); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIEmbeddingsActionTests.java new file mode 100644 index 0000000000000..3b3397f702a64 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIEmbeddingsActionTests.java @@ -0,0 +1,413 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.voyageai; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.VoyageAIEmbeddingsRequestManager; +import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIUtils; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingType; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings; +import org.hamcrest.MatcherAssert; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationBinary; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationByte; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; + +public class VoyageAIEmbeddingsActionTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testExecute_ReturnsSuccessfulResponse() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = HttpRequestSenderTests.createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "object": "list", + "data": [{ + "object": "embedding", + "embedding": [ + 0.123, + -0.123 + ], + "index": 0 + }], + "model": "voyage-3-large", + "usage": { + "total_tokens": 123 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var action = createAction( + getUrl(webServer), + "secret", + new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, true), + "model", + VoyageAIEmbeddingType.FLOAT, + sender + ); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + MatcherAssert.assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })))); + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().getFirst().getUri().getQuery()); + MatcherAssert.assertThat( + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + MatcherAssert.assertThat( + webServer.requests().getFirst().getHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER), + equalTo(VoyageAIUtils.ELASTIC_REQUEST_SOURCE) + ); + + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + MatcherAssert.assertThat( + requestMap, + equalTo( + Map.of( + "input", + List.of("abc"), + "model", + "model", + "input_type", + "document", + "output_dtype", + "float", + "truncation", + true, + "output_dimension", + 1024 + ) + ) + ); + } + } + + public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = HttpRequestSenderTests.createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "object": "list", + "data": [{ + "object": "embedding", + "embedding": [ + 0, + -1 + ], + "index": 0 + }], + "model": "voyage-3-large", + "usage": { + "total_tokens": 123 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var action = createAction( + getUrl(webServer), + "secret", + new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, true), + "model", + VoyageAIEmbeddingType.INT8, + sender + ); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + assertEquals(buildExpectationByte(List.of(new byte[] { 0, -1 })), result.asMap()); + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().getFirst().getUri().getQuery()); + MatcherAssert.assertThat( + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + MatcherAssert.assertThat( + webServer.requests().getFirst().getHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER), + equalTo(VoyageAIUtils.ELASTIC_REQUEST_SOURCE) + ); + + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + MatcherAssert.assertThat( + requestMap, + is( + Map.of( + "input", + List.of("abc"), + "model", + "model", + "input_type", + "document", + "output_dtype", + "int8", + "truncation", + true, + "output_dimension", + 1024 + ) + ) + ); + } + } + + public void testExecute_ReturnsSuccessfulResponse_ForBinaryResponseType() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = HttpRequestSenderTests.createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "object": "list", + "data": [{ + "object": "embedding", + "embedding": [ + 0, + -1 + ], + "index": 0 + }], + "model": "voyage-3-large", + "usage": { + "total_tokens": 123 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var action = createAction( + getUrl(webServer), + "secret", + new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, true), + "model", + VoyageAIEmbeddingType.BINARY, + sender + ); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + assertEquals(buildExpectationBinary(List.of(new byte[] { 0, -1 })), result.asMap()); + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().getFirst().getUri().getQuery()); + MatcherAssert.assertThat( + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + MatcherAssert.assertThat( + webServer.requests().getFirst().getHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER), + equalTo(VoyageAIUtils.ELASTIC_REQUEST_SOURCE) + ); + + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + MatcherAssert.assertThat( + requestMap, + is( + Map.of( + "input", + List.of("abc"), + "model", + "model", + "input_type", + "document", + "output_dtype", + "binary", + "truncation", + true, + "output_dimension", + 1024 + ) + ) + ); + } + } + + public void testExecute_ThrowsElasticsearchException() { + var sender = mock(Sender.class); + doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any(), any(), any()); + + var action = createAction(getUrl(webServer), "secret", VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "model", null, sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + MatcherAssert.assertThat(thrownException.getMessage(), is("failed")); + } + + public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() { + var sender = mock(Sender.class); + + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new IllegalStateException("failed")); + + return Void.TYPE; + }).when(sender).send(any(), any(), any(), any()); + + var action = createAction(getUrl(webServer), "secret", VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "model", null, sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + MatcherAssert.assertThat( + thrownException.getMessage(), + is(format("Failed to send VoyageAI embeddings request to [%s]", getUrl(webServer))) + ); + } + + public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled_WhenUrlIsNull() { + var sender = mock(Sender.class); + + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new IllegalStateException("failed")); + + return Void.TYPE; + }).when(sender).send(any(), any(), any(), any()); + + var action = createAction(null, "secret", VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "model", null, sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + MatcherAssert.assertThat(thrownException.getMessage(), is("Failed to send VoyageAI embeddings request")); + } + + public void testExecute_ThrowsException() { + var sender = mock(Sender.class); + doThrow(new IllegalArgumentException("failed")).when(sender).send(any(), any(), any(), any()); + + var action = createAction(getUrl(webServer), "secret", VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "model", null, sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + MatcherAssert.assertThat( + thrownException.getMessage(), + is(format("Failed to send VoyageAI embeddings request to [%s]", getUrl(webServer))) + ); + } + + public void testExecute_ThrowsExceptionWithNullUrl() { + var sender = mock(Sender.class); + doThrow(new IllegalArgumentException("failed")).when(sender).send(any(), any(), any(), any()); + + var action = createAction(null, "secret", VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "model", null, sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + MatcherAssert.assertThat(thrownException.getMessage(), is("Failed to send VoyageAI embeddings request")); + } + + private ExecutableAction createAction( + String url, + String apiKey, + VoyageAIEmbeddingsTaskSettings taskSettings, + @Nullable String modelName, + @Nullable VoyageAIEmbeddingType embeddingType, + Sender sender + ) { + var model = VoyageAIEmbeddingsModelTests.createModel(url, apiKey, taskSettings, 1024, 1024, modelName, embeddingType); + var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(model.uri(), "VoyageAI embeddings"); + var requestCreator = VoyageAIEmbeddingsRequestManager.of(model, threadPool); + return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntityTests.java new file mode 100644 index 0000000000000..66b2287e9cb50 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntityTests.java @@ -0,0 +1,173 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.voyageai; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.CoreMatchers.is; + +public class VoyageAIEmbeddingsRequestEntityTests extends ESTestCase { + public void testXContent_WritesAllFields_ServiceSettingsDefined() throws IOException { + var entity = new VoyageAIEmbeddingsRequestEntity( + List.of("abc"), + VoyageAIEmbeddingsServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.URL, + "https://www.abc.com", + ServiceFields.SIMILARITY, + SimilarityMeasure.DOT_PRODUCT.toString(), + ServiceFields.DIMENSIONS, + 2048, + ServiceFields.MAX_INPUT_TOKENS, + 512, + VoyageAIServiceSettings.MODEL_ID, + "model", + VoyageAIEmbeddingsServiceSettings.EMBEDDING_TYPE, + "float" + ) + ), + ConfigurationParseContext.PERSISTENT + ), + new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), + "model" + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, is(""" + {"input":["abc"],"model":"model","input_type":"document","output_dimension":2048,"output_dtype":"float"}""")); + } + + public void testXContent_WritesAllFields_ServiceSettingsDefined_Int8() throws IOException { + var entity = new VoyageAIEmbeddingsRequestEntity( + List.of("abc"), + VoyageAIEmbeddingsServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.URL, + "https://www.abc.com", + ServiceFields.SIMILARITY, + SimilarityMeasure.DOT_PRODUCT.toString(), + ServiceFields.DIMENSIONS, + 2048, + ServiceFields.MAX_INPUT_TOKENS, + 512, + VoyageAIServiceSettings.MODEL_ID, + "model", + VoyageAIEmbeddingsServiceSettings.EMBEDDING_TYPE, + "int8" + ) + ), + ConfigurationParseContext.PERSISTENT + ), + new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), + "model" + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, is(""" + {"input":["abc"],"model":"model","input_type":"document","output_dimension":2048,"output_dtype":"int8"}""")); + } + + public void testXContent_WritesAllFields_ServiceSettingsDefined_Binary() throws IOException { + var entity = new VoyageAIEmbeddingsRequestEntity( + List.of("abc"), + VoyageAIEmbeddingsServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.URL, + "https://www.abc.com", + ServiceFields.SIMILARITY, + SimilarityMeasure.DOT_PRODUCT.toString(), + ServiceFields.DIMENSIONS, + 2048, + ServiceFields.MAX_INPUT_TOKENS, + 512, + VoyageAIServiceSettings.MODEL_ID, + "model", + VoyageAIEmbeddingsServiceSettings.EMBEDDING_TYPE, + "binary" + ) + ), + ConfigurationParseContext.PERSISTENT + ), + new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), + "model" + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, is(""" + {"input":["abc"],"model":"model","input_type":"document","output_dimension":2048,"output_dtype":"binary"}""")); + } + + public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { + var entity = new VoyageAIEmbeddingsRequestEntity( + List.of("abc"), + VoyageAIEmbeddingsServiceSettings.EMPTY_SETTINGS, + new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), + "model" + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, is(""" + {"input":["abc"],"model":"model","input_type":"document"}""")); + } + + public void testXContent_WritesNoOptionalFields_WhenTheyAreNotDefined() throws IOException { + var entity = new VoyageAIEmbeddingsRequestEntity( + List.of("abc"), + VoyageAIEmbeddingsServiceSettings.EMPTY_SETTINGS, + VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + "model" + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, is(""" + {"input":["abc"],"model":"model"}""")); + } + + public void testConvertToString_ThrowsAssertionFailure_WhenInputTypeIsUnspecified() { + var thrownException = expectThrows( + AssertionError.class, + () -> VoyageAIEmbeddingsRequestEntity.convertToString(InputType.UNSPECIFIED) + ); + MatcherAssert.assertThat(thrownException.getMessage(), is("received invalid input type value [unspecified]")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestTests.java new file mode 100644 index 0000000000000..868849542457c --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestTests.java @@ -0,0 +1,215 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.voyageai; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingType; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class VoyageAIEmbeddingsRequestTests extends ESTestCase { + public void testCreateRequest_UrlDefined() throws IOException { + var request = createRequest( + List.of("abc"), + VoyageAIEmbeddingsModelTests.createModel("url", "secret", VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, "model") + ); + + var httpRequest = request.createHttpRequest(); + MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + MatcherAssert.assertThat( + httpPost.getLastHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER).getValue(), + is(VoyageAIUtils.ELASTIC_REQUEST_SOURCE) + ); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "model", "output_dtype", "float"))); + } + + public void testCreateRequest_AllOptionsDefined() throws IOException { + var request = createRequest( + List.of("abc"), + VoyageAIEmbeddingsModelTests.createModel( + "url", + "secret", + new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), + null, + null, + "model" + ) + ); + + var httpRequest = request.createHttpRequest(); + MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + MatcherAssert.assertThat( + httpPost.getLastHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER).getValue(), + is(VoyageAIUtils.ELASTIC_REQUEST_SOURCE) + ); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + MatcherAssert.assertThat( + requestMap, + is(Map.of("input", List.of("abc"), "model", "model", "input_type", "document", "output_dtype", "float")) + ); + } + + public void testCreateRequest_DimensionDefined() throws IOException { + var request = createRequest( + List.of("abc"), + VoyageAIEmbeddingsModelTests.createModel( + "url", + "secret", + new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), + null, + 2048, + "model" + ) + ); + + var httpRequest = request.createHttpRequest(); + MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + MatcherAssert.assertThat( + httpPost.getLastHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER).getValue(), + is(VoyageAIUtils.ELASTIC_REQUEST_SOURCE) + ); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + MatcherAssert.assertThat( + requestMap, + is( + Map.of( + "input", + List.of("abc"), + "model", + "model", + "input_type", + "document", + "output_dtype", + "float", + "output_dimension", + 2048 + ) + ) + ); + } + + public void testCreateRequest_EmbeddingTypeDefined() throws IOException { + var request = createRequest( + List.of("abc"), + VoyageAIEmbeddingsModelTests.createModel( + "url", + "secret", + new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), + null, + 2048, + "model", + VoyageAIEmbeddingType.BYTE + ) + ); + + var httpRequest = request.createHttpRequest(); + MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + MatcherAssert.assertThat( + httpPost.getLastHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER).getValue(), + is(VoyageAIUtils.ELASTIC_REQUEST_SOURCE) + ); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + MatcherAssert.assertThat( + requestMap, + is( + Map.of( + "input", + List.of("abc"), + "model", + "model", + "input_type", + "document", + "output_dtype", + "int8", + "output_dimension", + 2048 + ) + ) + ); + } + + public void testCreateRequest_InputTypeSearch() throws IOException { + var request = createRequest( + List.of("abc"), + VoyageAIEmbeddingsModelTests.createModel( + "url", + "secret", + new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH, null), + null, + null, + "model" + ) + ); + + var httpRequest = request.createHttpRequest(); + MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + MatcherAssert.assertThat( + httpPost.getLastHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER).getValue(), + is(VoyageAIUtils.ELASTIC_REQUEST_SOURCE) + ); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + MatcherAssert.assertThat( + requestMap, + is(Map.of("input", List.of("abc"), "model", "model", "input_type", "query", "output_dtype", "float")) + ); + } + + public static VoyageAIEmbeddingsRequest createRequest(List input, VoyageAIEmbeddingsModel model) { + return new VoyageAIEmbeddingsRequest(input, model); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRequestTests.java new file mode 100644 index 0000000000000..5244ef83a7c7b --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRequestTests.java @@ -0,0 +1,36 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.voyageai; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIAccount; + +import java.net.URI; + +import static org.hamcrest.Matchers.is; + +public class VoyageAIRequestTests extends ESTestCase { + + public void testDecorateWithHeaders() { + var request = new HttpPost("http://www.abc.com"); + + VoyageAIRequest.decorateWithHeaders( + request, + new VoyageAIAccount(URI.create("http://www.abc.com"), new SecureString(new char[] { 'a', 'b', 'c' })) + ); + + assertThat(request.getFirstHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(request.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer abc")); + assertThat(request.getFirstHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER).getValue(), is(VoyageAIUtils.ELASTIC_REQUEST_SOURCE)); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntityTests.java new file mode 100644 index 0000000000000..ae431b4b7bb13 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntityTests.java @@ -0,0 +1,186 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.voyageai; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankTaskSettings; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString; + +public class VoyageAIRerankRequestEntityTests extends ESTestCase { + public void testXContent_SingleRequest_WritesModelAndTopKIfDefined() throws IOException { + var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, null, null), "model"); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "documents": [ + "abc" + ], + "top_k": 8 + } + """)); + } + + public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_ReturnDocumentsTrue() throws IOException { + var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, true, null), "model"); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "documents": [ + "abc" + ], + "return_documents": true, + "top_k": 8 + } + """)); + } + + public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_ReturnDocumentsFalse() throws IOException { + var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, false, null), "model"); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "documents": [ + "abc" + ], + "return_documents": false, + "top_k": 8 + } + """)); + } + + public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_TruncationTrue() throws IOException { + var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, false, true), "model"); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "documents": [ + "abc" + ], + "return_documents": false, + "top_k": 8, + "truncation": true + } + """)); + } + + public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_TruncationFalse() throws IOException { + var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, false, false), "model"); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "documents": [ + "abc" + ], + "return_documents": false, + "top_k": 8, + "truncation": false + } + """)); + } + + public void testXContent_SingleRequest_DoesNotWriteTopKIfNull() throws IOException { + var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), null, "model"); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "documents": [ + "abc" + ] + } + """)); + } + + public void testXContent_MultipleRequests_WritesModelAndTopKIfDefined() throws IOException { + var entity = new VoyageAIRerankRequestEntity( + "query", + List.of("abc", "def"), + new VoyageAIRerankTaskSettings(8, null, null), + "model" + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "documents": [ + "abc", + "def" + ], + "top_k": 8 + } + """)); + } + + public void testXContent_MultipleRequests_DoesNotWriteTopKIfNull() throws IOException { + var entity = new VoyageAIRerankRequestEntity("query", List.of("abc", "def"), null, "model"); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "documents": [ + "abc", + "def" + ] + } + """)); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestTests.java new file mode 100644 index 0000000000000..a11d259200b98 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestTests.java @@ -0,0 +1,110 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.voyageai; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModelTests; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; + +public class VoyageAIRerankRequestTests extends ESTestCase { + + private static final String API_KEY = "foo"; + + public void testCreateRequest_WithoutModelSet_And_WithoutTopNSet() throws IOException { + var input = "input"; + var query = "query"; + var modelId = "model"; + + var request = createRequest(query, input, modelId, null); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer " + API_KEY)); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + + assertThat(requestMap, aMapWithSize(3)); + assertThat(requestMap.get("documents"), is(List.of(input))); + assertThat(requestMap.get("query"), is(query)); + assertThat(requestMap.get("model"), is(modelId)); + } + + public void testCreateRequest_WithTopNSet() throws IOException { + var input = "input"; + var query = "query"; + var topK = 1; + var modelId = "model"; + + var request = createRequest(query, input, modelId, topK); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer " + API_KEY)); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + + assertThat(requestMap, aMapWithSize(4)); + assertThat(requestMap.get("documents"), is(List.of(input))); + assertThat(requestMap.get("query"), is(query)); + assertThat(requestMap.get("top_k"), is(topK)); + assertThat(requestMap.get("model"), is(modelId)); + } + + public void testCreateRequest_WithModelSet() throws IOException { + var input = "input"; + var query = "query"; + var modelId = "model"; + + var request = createRequest(query, input, modelId, null); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer " + API_KEY)); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + + assertThat(requestMap, aMapWithSize(3)); + assertThat(requestMap.get("documents"), is(List.of(input))); + assertThat(requestMap.get("query"), is(query)); + assertThat(requestMap.get("model"), is(modelId)); + } + + public void testTruncate_DoesNotTruncate() { + var request = createRequest("query", "input", "null", null); + var truncatedRequest = request.truncate(); + + assertThat(truncatedRequest, sameInstance(request)); + } + + private static VoyageAIRerankRequest createRequest(String query, String input, @Nullable String modelId, @Nullable Integer topK) { + var rerankModel = VoyageAIRerankModelTests.createModel(API_KEY, modelId, topK); + return new VoyageAIRerankRequest(query, List.of(input), rerankModel); + + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIUtilsTests.java new file mode 100644 index 0000000000000..4a8271f4fac88 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIUtilsTests.java @@ -0,0 +1,23 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.voyageai; + +import org.elasticsearch.test.ESTestCase; + +import static org.hamcrest.Matchers.is; + +public class VoyageAIUtilsTests extends ESTestCase { + + public void testCreateRequestSourceHeader() { + var requestSourceHeader = VoyageAIUtils.createRequestSourceHeader(); + + assertThat(requestSourceHeader.getName(), is("Request-Source")); + assertThat(requestSourceHeader.getValue(), is("unspecified:elasticsearch")); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntityTests.java new file mode 100644 index 0000000000000..2b1c8fa43af53 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntityTests.java @@ -0,0 +1,432 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.voyageai; + +import org.apache.http.HttpResponse; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentParseException; +import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIEmbeddingsRequest; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.List; + +import static org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModelTests.createModel; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class VoyageAIEmbeddingsResponseEntityTests extends ESTestCase { + public void testFromResponse_CreatesResultsForASingleItem() throws IOException { + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.014539449, + -0.015288644 + ] + } + ], + "model": "voyage-3-large", + "usage": { + "total_tokens": 8 + } + } + """; + + VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest( + List.of("abc", "def"), + createModel("url", "api_key", null, "voyage-3-large") + ); + + InferenceServiceResults parsedResults = VoyageAIEmbeddingsResponseEntity.fromResponse( + request, + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat( + ((InferenceTextEmbeddingFloatResults) parsedResults).embeddings(), + is(List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.014539449F, -0.015288644F }))) + ); + } + + public void testFromResponse_CreatesResultsForMultipleItems() throws IOException { + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.014539449, + -0.015288644 + ] + }, + { + "object": "embedding", + "index": 1, + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "model": "voyage-3-large", + "usage": { + "total_tokens": 8 + } + } + """; + + VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest( + List.of("abc", "def"), + createModel("url", "api_key", null, "voyage-3-large") + ); + + InferenceServiceResults parsedResults = VoyageAIEmbeddingsResponseEntity.fromResponse( + request, + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat( + ((InferenceTextEmbeddingFloatResults) parsedResults).embeddings(), + is( + List.of( + new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.014539449F, -0.015288644F }), + new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.0123F, -0.0123F }) + ) + ) + ); + } + + public void testFromResponse_FailsWhenDataFieldIsNotPresent() { + String responseJson = """ + { + "object": "list", + "not_data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.014539449, + -0.015288644 + ] + } + ], + "model": "voyage-3-large", + "usage": { + "total_tokens": 8 + } + } + """; + + VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest( + List.of("abc", "def"), + createModel("url", "api_key", null, "voyage-3-large") + ); + + var thrownException = expectThrows( + java.lang.IllegalArgumentException.class, + () -> VoyageAIEmbeddingsResponseEntity.fromResponse( + request, + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + + assertThat(thrownException.getMessage(), is("Required [data]")); + } + + public void testFromResponse_FailsWhenDataFieldNotAnArray() { + String responseJson = """ + { + "object": "list", + "data": { + "test": { + "object": "embedding", + "index": 0, + "embedding": [ + 0.014539449, + -0.015288644 + ] + } + }, + "model": "voyage-3-large", + "usage": { + "total_tokens": 8 + } + } + """; + + VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest( + List.of("abc", "def"), + createModel("url", "api_key", null, "voyage-3-large") + ); + + var thrownException = expectThrows( + XContentParseException.class, + () -> VoyageAIEmbeddingsResponseEntity.fromResponse( + request, + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + + assertThat(thrownException.getMessage(), containsString("[EmbeddingFloatResult] failed to parse field [data]")); + } + + public void testFromResponse_FailsWhenEmbeddingsDoesNotExist() { + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embeddingzzz": [ + 0.014539449, + -0.015288644 + ] + } + ], + "model": "voyage-3-large", + "usage": { + "total_tokens": 8 + } + } + """; + + VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest( + List.of("abc", "def"), + createModel("url", "api_key", null, "voyage-3-large") + ); + + var thrownException = expectThrows( + XContentParseException.class, + () -> VoyageAIEmbeddingsResponseEntity.fromResponse( + request, + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + + assertThat(thrownException.getMessage(), containsString("[EmbeddingFloatResult] failed to parse field [data]")); + } + + public void testFromResponse_FailsWhenEmbeddingValueIsAString() { + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + "abc" + ] + } + ], + "model": "voyage-3-large", + "usage": { + "total_tokens": 8 + } + } + """; + + VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest( + List.of("abc", "def"), + createModel("url", "api_key", null, "voyage-3-large") + ); + + var thrownException = expectThrows( + XContentParseException.class, + () -> VoyageAIEmbeddingsResponseEntity.fromResponse( + request, + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + + assertThat(thrownException.getMessage(), is("[8:15] [EmbeddingFloatResult] failed to parse field [data]")); + } + + public void testFromResponse_SucceedsWhenEmbeddingValueIsInt() throws IOException { + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 1 + ] + } + ], + "model": "voyage-3-large", + "usage": { + "total_tokens": 8 + } + } + """; + + VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest( + List.of("abc", "def"), + createModel("url", "api_key", null, "voyage-3-large") + ); + + InferenceServiceResults parsedResults = VoyageAIEmbeddingsResponseEntity.fromResponse( + request, + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat( + ((InferenceTextEmbeddingFloatResults) parsedResults).embeddings(), + is(List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 1.0F }))) + ); + } + + public void testFromResponse_SucceedsWhenEmbeddingValueIsLong() throws IOException { + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 40294967295 + ] + } + ], + "model": "voyage-3-large", + "usage": { + "total_tokens": 8 + } + } + """; + + VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest( + List.of("abc", "def"), + createModel("url", "api_key", null, "voyage-3-large") + ); + + InferenceServiceResults parsedResults = VoyageAIEmbeddingsResponseEntity.fromResponse( + request, + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat( + ((InferenceTextEmbeddingFloatResults) parsedResults).embeddings(), + is(List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 4.0294965E10F }))) + ); + } + + public void testFromResponse_FailsWhenEmbeddingValueIsAnObject() { + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + {} + ] + } + ], + "model": "voyage-3-large", + "usage": { + "total_tokens": 8 + } + } + """; + + VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest( + List.of("abc", "def"), + createModel("url", "api_key", null, "voyage-3-large") + ); + + var thrownException = expectThrows( + XContentParseException.class, + () -> VoyageAIEmbeddingsResponseEntity.fromResponse( + request, + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + + assertThat(thrownException.getMessage(), is("[8:15] [EmbeddingFloatResult] failed to parse field [data]")); + } + + public void testFieldsInDifferentOrderServer() throws IOException { + // The fields of the objects in the data array are reordered + String response = """ + { + "object": "list", + "model": "voyage-3-large", + "data": [ + { + "embedding": [ + -0.9, + 0.5, + 0.3 + ], + "index": 0, + "object": "embedding" + }, + { + "index": 0, + "embedding": [ + 0.1, + 0.5 + ], + "object": "embedding" + }, + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.5, + 0.5 + ] + } + ], + "usage": { + "total_tokens": 0 + } + }"""; + + VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest( + List.of("abc", "def"), + createModel("url", "api_key", null, "voyage-3-large") + ); + + InferenceServiceResults parsedResults = VoyageAIEmbeddingsResponseEntity.fromResponse( + request, + new HttpResult(mock(HttpResponse.class), response.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat(parsedResults, instanceOf(InferenceTextEmbeddingFloatResults.class)); + + assertThat( + ((InferenceTextEmbeddingFloatResults) parsedResults).embeddings(), + is( + List.of( + new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { -0.9F, 0.5F, 0.3F }), + new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.1F, 0.5F }), + new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.5F, 0.5F }) + ) + ) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIErrorResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIErrorResponseEntityTests.java new file mode 100644 index 0000000000000..bb5c1da90d776 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIErrorResponseEntityTests.java @@ -0,0 +1,51 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.voyageai; + +import org.apache.http.HttpResponse; +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; +import org.hamcrest.MatcherAssert; + +import java.nio.charset.StandardCharsets; + +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class VoyageAIErrorResponseEntityTests extends ESTestCase { + public void testFromResponse() { + String message = "\"input\" length 2049 is larger than the largest allowed size 2048"; + String escapedMessage = message.replace("\\", "\\\\").replace("\"", "\\\""); + String responseJson = Strings.format(""" + { + "detail": "%s" + } + """, escapedMessage); + + ErrorResponse errorResponse = VoyageAIErrorResponseEntity.fromResponse( + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + assertNotNull(errorResponse); + MatcherAssert.assertThat(errorResponse.getErrorMessage(), is(message)); + } + + public void testFromResponse_noMessage() { + String responseJson = """ + { + "error": "abc" + } + """; + + ErrorResponse errorResponse = VoyageAIErrorResponseEntity.fromResponse( + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + MatcherAssert.assertThat(errorResponse, is(ErrorResponse.UNDEFINED_ERROR)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIRerankResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIRerankResponseEntityTests.java new file mode 100644 index 0000000000000..5c7aa7a80ad8f --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIRerankResponseEntityTests.java @@ -0,0 +1,173 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.voyageai; + +import org.apache.http.HttpResponse; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; + +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class VoyageAIRerankResponseEntityTests extends ESTestCase { + + public void testResponseLiteral() throws IOException { + String responseLiteral = """ + { + "object": "list", + "model": "model", + "data": [ + { + "index": 2, + "relevance_score": 0.98005307 + }, + { + "index": 3, + "relevance_score": 0.27904198 + }, + { + "index": 0, + "relevance_score": 0.10194652 + } + ], + "usage": { + "total_tokens": 15 + } + } + """; + InferenceServiceResults parsedResults = VoyageAIRerankResponseEntity.fromResponse( + new HttpResult(mock(HttpResponse.class), responseLiteral.getBytes(StandardCharsets.UTF_8)) + ); + + MatcherAssert.assertThat(parsedResults, instanceOf(RankedDocsResults.class)); + List expected = responseLiteralDocs(); + for (int i = 0; i < ((RankedDocsResults) parsedResults).getRankedDocs().size(); i++) { + assertEquals(((RankedDocsResults) parsedResults).getRankedDocs().get(i).index(), expected.get(i).index()); + } + } + + public void testGeneratedResponse() throws IOException { + int numDocs = randomIntBetween(1, 10); + + List expected = new ArrayList<>(numDocs); + StringBuilder responseBuilder = new StringBuilder(); + + responseBuilder.append("{"); + responseBuilder.append("\"model\": \"model\","); + responseBuilder.append("\"object\": \"list\","); + responseBuilder.append("\"data\": ["); + List indices = linear(numDocs); + List scores = linearFloats(numDocs); + for (int i = 0; i < numDocs; i++) { + int index = indices.remove(randomInt(indices.size() - 1)); + + responseBuilder.append("{"); + responseBuilder.append("\"index\":").append(index).append(","); + responseBuilder.append("\"relevance_score\":").append(scores.get(i).toString()).append("}"); + expected.add(new RankedDocsResults.RankedDoc(index, scores.get(i), null)); + if (i < numDocs - 1) { + responseBuilder.append(","); + } + } + responseBuilder.append("],"); + responseBuilder.append("\"usage\": {"); + responseBuilder.append("\"total_tokens\": 15}"); + responseBuilder.append("}"); + + InferenceServiceResults parsedResults = VoyageAIRerankResponseEntity.fromResponse( + new HttpResult(mock(HttpResponse.class), responseBuilder.toString().getBytes(StandardCharsets.UTF_8)) + ); + MatcherAssert.assertThat(parsedResults, instanceOf(RankedDocsResults.class)); + for (int i = 0; i < ((RankedDocsResults) parsedResults).getRankedDocs().size(); i++) { + assertEquals(((RankedDocsResults) parsedResults).getRankedDocs().get(i).index(), expected.get(i).index()); + } + } + + private ArrayList responseLiteralDocs() { + var list = new ArrayList(); + + list.add(new RankedDocsResults.RankedDoc(2, 0.98005307F, null)); + list.add(new RankedDocsResults.RankedDoc(3, 0.27904198F, null)); + list.add(new RankedDocsResults.RankedDoc(0, 0.10194652F, null)); + return list; + } + + public void testResponseLiteralWithDocuments() throws IOException { + String responseLiteralWithDocuments = """ + { + "object": "list", + "model": "model", + "data": [ + { + "document": "Washington, D.C..", + "index": 2, + "relevance_score": 0.98005307 + }, + { + "document": "Capital punishment has existed in the United States since beforethe United States was a country. ", + "index": 3, + "relevance_score": 0.27904198 + }, + { + "document": "Carson City is the capital city of the American state of Nevada.", + "index": 0, + "relevance_score": 0.10194652 + } + ], + "usage": { + "total_tokens": 15 + } + } + """; + InferenceServiceResults parsedResults = VoyageAIRerankResponseEntity.fromResponse( + new HttpResult(mock(HttpResponse.class), responseLiteralWithDocuments.getBytes(StandardCharsets.UTF_8)) + ); + + MatcherAssert.assertThat(parsedResults, instanceOf(RankedDocsResults.class)); + MatcherAssert.assertThat(((RankedDocsResults) parsedResults).getRankedDocs(), is(responseLiteralDocsWithText)); + } + + private final List responseLiteralDocsWithText = List.of( + new RankedDocsResults.RankedDoc(2, 0.98005307F, "Washington, D.C.."), + new RankedDocsResults.RankedDoc( + 3, + 0.27904198F, + "Capital punishment has existed in the United States since beforethe United States was a country. " + ), + new RankedDocsResults.RankedDoc(0, 0.10194652F, "Carson City is the capital city of the American state of Nevada.") + ); + + private ArrayList linear(int n) { + ArrayList list = new ArrayList<>(); + for (int i = 0; i <= n; i++) { + list.add(i); + } + return list; + } + + // creates a list of doubles of monotonically decreasing magnitude + private ArrayList linearFloats(int n) { + ArrayList list = new ArrayList<>(); + float startValue = 1.0f; + float decrement = startValue / n + 1; + for (int i = 0; i <= n; i++) { + list.add(startValue - (i * decrement)); + } + return list; + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/voyageai/VoyageAIResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/voyageai/VoyageAIResponseHandlerTests.java new file mode 100644 index 0000000000000..d032116d9a894 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/voyageai/VoyageAIResponseHandlerTests.java @@ -0,0 +1,138 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.voyageai; + +import org.apache.http.Header; +import org.apache.http.HeaderElement; +import org.apache.http.HttpResponse; +import org.apache.http.StatusLine; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.hamcrest.MatcherAssert; + +import java.nio.charset.StandardCharsets; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.core.Is.is; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class VoyageAIResponseHandlerTests extends ESTestCase { + public void testCheckForFailureStatusCode_DoesNotThrowForStatusCodesBetween200And299() { + callCheckForFailureStatusCode(randomIntBetween(200, 299), "id"); + } + + public void testCheckForFailureStatusCode_ThrowsFor503() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(503, "id")); + assertFalse(exception.shouldRetry()); + MatcherAssert.assertThat( + exception.getCause().getMessage(), + containsString("Received a server error status code for request from inference entity id [id] status [503]") + ); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST)); + } + + public void testCheckForFailureStatusCode_ThrowsFor500_WithShouldRetryTrue() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(500, "id")); + assertTrue(exception.shouldRetry()); + MatcherAssert.assertThat( + exception.getCause().getMessage(), + containsString("Received a server error status code for request from inference entity id [id] status [500]") + ); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST)); + } + + public void testCheckForFailureStatusCode_ThrowsFor429_WithShouldRetryTrue() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(429, "id")); + assertTrue(exception.shouldRetry()); + MatcherAssert.assertThat( + exception.getCause().getMessage(), + containsString("Received a rate limit status code for request from inference entity id [id] status [429]") + ); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.TOO_MANY_REQUESTS)); + } + + public void testCheckForFailureStatusCode_ThrowsFor400() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(400, "id")); + assertFalse(exception.shouldRetry()); + MatcherAssert.assertThat( + exception.getCause().getMessage(), + containsString("Received an input validation error response for request from inference entity id [id] status [400]") + ); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST)); + } + + public void testCheckForFailureStatusCode_ThrowsFor400_InputsTooLarge() { + var exception = expectThrows( + RetryException.class, + () -> callCheckForFailureStatusCode(400, "\"input\" length 2049 is larger than the largest allowed size 2048", "id") + ); + assertFalse(exception.shouldRetry()); + MatcherAssert.assertThat( + exception.getCause().getMessage(), + containsString("Received an input validation error response for request from inference entity id [id] status [400]") + ); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST)); + } + + public void testCheckForFailureStatusCode_ThrowsFor401() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(401, "inferenceEntityId")); + assertFalse(exception.shouldRetry()); + MatcherAssert.assertThat( + exception.getCause().getMessage(), + containsString( + "Received an authentication error status code for request from inference entity id [inferenceEntityId] status [401]" + ) + ); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.UNAUTHORIZED)); + } + + public void testCheckForFailureStatusCode_ThrowsFor402() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(402, "inferenceEntityId")); + assertFalse(exception.shouldRetry()); + MatcherAssert.assertThat(exception.getCause().getMessage(), containsString("Payment required")); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.PAYMENT_REQUIRED)); + } + + private static void callCheckForFailureStatusCode(int statusCode, String modelId) { + callCheckForFailureStatusCode(statusCode, null, modelId); + } + + private static void callCheckForFailureStatusCode(int statusCode, @Nullable String errorMessage, String modelId) { + var statusLine = mock(StatusLine.class); + when(statusLine.getStatusCode()).thenReturn(statusCode); + + var httpResponse = mock(HttpResponse.class); + when(httpResponse.getStatusLine()).thenReturn(statusLine); + var header = mock(Header.class); + when(header.getElements()).thenReturn(new HeaderElement[] {}); + when(httpResponse.getFirstHeader(anyString())).thenReturn(header); + + String escapedErrorMessage = errorMessage != null ? errorMessage.replace("\\", "\\\\").replace("\"", "\\\"") : errorMessage; + + String responseJson = Strings.format(""" + { + "detail": "%s" + } + """, escapedErrorMessage); + + var mockRequest = mock(Request.class); + when(mockRequest.getInferenceEntityId()).thenReturn(modelId); + var httpResult = new HttpResult(httpResponse, errorMessage == null ? new byte[] {} : responseJson.getBytes(StandardCharsets.UTF_8)); + var handler = new VoyageAIResponseHandler("", (request, result) -> null); + + handler.checkForFailureStatusCode(mockRequest, httpResult); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java index 56bd690a9cdbf..09b73dc260693 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java @@ -146,4 +146,7 @@ public static Map buildExpectationByte(List embeddings) ); } + public static Map buildExpectationBinary(List embeddings) { + return Map.of("text_embedding_bits", embeddings.stream().map(InferenceByteEmbedding::new).toList()); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceSettingsTests.java new file mode 100644 index 0000000000000..09d890bd21f67 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceSettingsTests.java @@ -0,0 +1,112 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.is; + +public class VoyageAIServiceSettingsTests extends AbstractWireSerializingTestCase { + + public static VoyageAIServiceSettings createRandomWithNonNullUrl() { + return createRandom(); + } + + /** + * The created settings can have a url set to null. + */ + public static VoyageAIServiceSettings createRandom() { + var model = randomAlphaOfLength(15); + + return new VoyageAIServiceSettings(model, RateLimitSettingsTests.createRandom()); + } + + public void testFromMap() { + var model = "model"; + var serviceSettings = VoyageAIServiceSettings.fromMap( + new HashMap<>(Map.of(VoyageAIServiceSettings.MODEL_ID, model)), + ConfigurationParseContext.REQUEST + ); + + MatcherAssert.assertThat(serviceSettings, is(new VoyageAIServiceSettings(model, null))); + } + + public void testFromMap_WithRateLimit() { + var model = "model"; + var serviceSettings = VoyageAIServiceSettings.fromMap( + new HashMap<>( + Map.of( + VoyageAIServiceSettings.MODEL_ID, + model, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 3)) + ) + ), + ConfigurationParseContext.REQUEST + ); + + MatcherAssert.assertThat(serviceSettings, is(new VoyageAIServiceSettings(model, new RateLimitSettings(3)))); + } + + public void testFromMap_WhenUsingModelId() { + var model = "model"; + var serviceSettings = VoyageAIServiceSettings.fromMap( + new HashMap<>(Map.of(VoyageAIServiceSettings.MODEL_ID, model)), + ConfigurationParseContext.PERSISTENT + ); + + MatcherAssert.assertThat(serviceSettings, is(new VoyageAIServiceSettings(model, null))); + } + + public void testXContent_WritesModelId() throws IOException { + var entity = new VoyageAIServiceSettings("model", new RateLimitSettings(1)); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(""" + {"model_id":"model","rate_limit":{"requests_per_minute":1}}""")); + } + + @Override + protected Writeable.Reader instanceReader() { + return VoyageAIServiceSettings::new; + } + + @Override + protected VoyageAIServiceSettings createTestInstance() { + return createRandomWithNonNullUrl(); + } + + @Override + protected VoyageAIServiceSettings mutateInstance(VoyageAIServiceSettings instance) throws IOException { + return randomValueOtherThan(instance, VoyageAIServiceSettingsTests::createRandom); + } + + public static Map getServiceSettingsMap(String model) { + var map = new HashMap(); + + map.put(VoyageAIServiceSettings.MODEL_ID, model); + + return map; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java new file mode 100644 index 0000000000000..d45c78002fa4d --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java @@ -0,0 +1,1971 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettingsTests; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettingsTests; +import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModelTests; +import org.hamcrest.CoreMatchers; +import org.hamcrest.MatcherAssert; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; +import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; +import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettings; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +public class VoyageAIServiceTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testParseRequestConfig_CreatesAVoyageAIEmbeddingsModel() throws IOException { + try (var service = createVoyageAIService()) { + ActionListener modelListener = ActionListener.wrap(model -> { + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.buildUri().toString(), is("https://api.voyageai.com/v1/embeddings")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null))); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + }, e -> fail("Model parsing should have succeeded " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST), + getSecretSettingsMap("secret") + ), + modelListener + ); + + } + } + + public void testParseRequestConfig_CreatesAVoyageAIEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { + try (var service = createVoyageAIService()) { + ActionListener modelListener = ActionListener.wrap(model -> { + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.buildUri().toString(), is("https://api.voyageai.com/v1/embeddings")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null))); + MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + }, e -> fail("Model parsing should have succeeded " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ), + modelListener + ); + + } + } + + public void testParseRequestConfig_CreatesAVoyageAIEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { + try (var service = createVoyageAIService()) { + ActionListener modelListener = ActionListener.wrap(model -> { + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.buildUri().toString(), is("https://api.voyageai.com/v1/embeddings")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null))); + MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + }, e -> fail("Model parsing should have succeeded " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST), + getSecretSettingsMap("secret") + ), + modelListener + ); + + } + } + + public void testParseRequestConfig_OptionalTaskSettings() throws IOException { + try (var service = createVoyageAIService()) { + + ActionListener modelListener = ActionListener.wrap(model -> { + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.buildUri().toString(), is("https://api.voyageai.com/v1/embeddings")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), equalTo(VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS)); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + }, e -> fail("Model parsing should have succeeded " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap(VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), getSecretSettingsMap("secret")), + modelListener + ); + + } + } + + public void testParseRequestConfig_ThrowsUnsupportedTaskType() throws IOException { + try (var service = createVoyageAIService()) { + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "The [voyageai] service does not support task type [sparse_embedding]" + ); + + service.parseRequestConfig( + "id", + TaskType.SPARSE_EMBEDDING, + getRequestConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(), + getSecretSettingsMap("secret") + ), + failureListener + ); + } + } + + private static ActionListener getModelListenerForException(Class exceptionClass, String expectedMessage) { + return ActionListener.wrap((model) -> fail("Model parsing should have failed"), e -> { + MatcherAssert.assertThat(e, instanceOf(exceptionClass)); + MatcherAssert.assertThat(e.getMessage(), is(expectedMessage)); + }); + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException { + try (var service = createVoyageAIService()) { + var config = getRequestConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(), + getSecretSettingsMap("secret") + ); + config.put("extra_key", "value"); + + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "Model configuration contains settings [{extra_key=value}] unknown to the [voyageai] service" + ); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException { + try (var service = createVoyageAIService()) { + var serviceSettings = VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"); + serviceSettings.put("extra_key", "value"); + + var config = getRequestConfigMap( + serviceSettings, + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + getSecretSettingsMap("secret") + ); + + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "Model configuration contains settings [{extra_key=value}] unknown to the [voyageai] service" + ); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException { + try (var service = createVoyageAIService()) { + var taskSettingsMap = VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST); + taskSettingsMap.put("extra_key", "value"); + + var config = getRequestConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), + taskSettingsMap, + getSecretSettingsMap("secret") + ); + + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "Model configuration contains settings [{extra_key=value}] unknown to the [voyageai] service" + ); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); + + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap() throws IOException { + try (var service = createVoyageAIService()) { + var secretSettingsMap = getSecretSettingsMap("secret"); + secretSettingsMap.put("extra_key", "value"); + + var config = getRequestConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(), + secretSettingsMap + ); + + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "Model configuration contains settings [{extra_key=value}] unknown to the [voyageai] service" + ); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAVoyageAIEmbeddingsModel() throws IOException { + try (var service = createVoyageAIService()) { + var persistedConfig = getPersistedConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.buildUri().toString(), is("https://api.voyageai.com/v1/embeddings")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null, null))); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAVoyageAIEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { + try (var service = createVoyageAIService()) { + var persistedConfig = getPersistedConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.buildUri().toString(), is("https://api.voyageai.com/v1/embeddings")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null, null))); + MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAVoyageAIEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { + try (var service = createVoyageAIService()) { + var persistedConfig = getPersistedConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.buildUri().toString(), is("https://api.voyageai.com/v1/embeddings")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null, null))); + MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidModel() throws IOException { + try (var service = createVoyageAIService()) { + var persistedConfig = getPersistedConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("oldmodel"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(), + getSecretSettingsMap("secret") + ); + + var thrownException = expectThrows( + ElasticsearchStatusException.class, + () -> service.parsePersistedConfigWithSecrets( + "id", + TaskType.SPARSE_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + is("Failed to parse stored model [id] for [voyageai] service, please delete and add the service again") + ); + } + } + + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + try (var service = createVoyageAIService()) { + var persistedConfig = getPersistedConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH), + getSecretSettingsMap("secret") + ); + persistedConfig.config().put("extra_key", "value"); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.buildUri().toString(), is("https://api.voyageai.com/v1/embeddings")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH, null))); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { + try (var service = createVoyageAIService()) { + var secretSettingsMap = getSecretSettingsMap("secret"); + secretSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(), + secretSettingsMap + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.buildUri().toString(), is("https://api.voyageai.com/v1/embeddings")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS)); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSecrets() throws IOException { + try (var service = createVoyageAIService()) { + var persistedConfig = getPersistedConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + getSecretSettingsMap("secret") + ); + persistedConfig.secrets().put("extra_key", "value"); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.buildUri().toString(), is("https://api.voyageai.com/v1/embeddings")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null, null))); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + try (var service = createVoyageAIService()) { + var serviceSettingsMap = VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"); + serviceSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap( + serviceSettingsMap, + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.buildUri().toString(), is("https://api.voyageai.com/v1/embeddings")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS)); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + try (var service = createVoyageAIService()) { + var taskSettingsMap = VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH); + taskSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), + taskSettingsMap, + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.buildUri().toString(), is("https://api.voyageai.com/v1/embeddings")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH, null))); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + public void testParsePersistedConfig_CreatesAVoyageAIEmbeddingsModel() throws IOException { + try (var service = createVoyageAIService()) { + var persistedConfig = getPersistedConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null) + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.buildUri().toString(), is("https://api.voyageai.com/v1/embeddings")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null, null))); + assertNull(embeddingsModel.getSecretSettings()); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + public void testParsePersistedConfig_CreatesAVoyageAIEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { + try (var service = createVoyageAIService()) { + var persistedConfig = getPersistedConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + createRandomChunkingSettingsMap() + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.buildUri().toString(), is("https://api.voyageai.com/v1/embeddings")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null, null))); + MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertNull(embeddingsModel.getSecretSettings()); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + public void testParsePersistedConfig_CreatesAVoyageAIEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { + try (var service = createVoyageAIService()) { + var persistedConfig = getPersistedConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null) + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.buildUri().toString(), is("https://api.voyageai.com/v1/embeddings")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null, null))); + MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertNull(embeddingsModel.getSecretSettings()); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() throws IOException { + try (var service = createVoyageAIService()) { + var persistedConfig = getPersistedConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model_old"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty() + ); + + var thrownException = expectThrows( + ElasticsearchStatusException.class, + () -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config()) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + is("Failed to parse stored model [id] for [voyageai] service, please delete and add the service again") + ); + } + } + + public void testParsePersistedConfig_CreatesAVoyageAIEmbeddingsModelWithoutUrl() throws IOException { + try (var service = createVoyageAIService()) { + var persistedConfig = getPersistedConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null) + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null, null))); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + try (var service = createVoyageAIService()) { + var persistedConfig = getPersistedConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty() + ); + persistedConfig.config().put("extra_key", "value"); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.buildUri().toString(), is("https://api.voyageai.com/v1/embeddings")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS)); + assertNull(embeddingsModel.getSecretSettings()); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + try (var service = createVoyageAIService()) { + var serviceSettingsMap = VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"); + serviceSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap( + serviceSettingsMap, + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH) + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.buildUri().toString(), is("https://api.voyageai.com/v1/embeddings")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH, null))); + assertNull(embeddingsModel.getSecretSettings()); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + try (var service = createVoyageAIService()) { + var taskSettingsMap = VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST); + taskSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), + taskSettingsMap + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.buildUri().toString(), is("https://api.voyageai.com/v1/embeddings")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null))); + assertNull(embeddingsModel.getSecretSettings()); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + public void testInfer_ThrowsErrorWhenModelIsNotVoyageAIModel() throws IOException { + var sender = mock(Sender.class); + + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + var mockModel = getInvalidModel("model_id", "service_name"); + + try (var service = new VoyageAIService(factory, createWithEmptySettings(threadPool))) { + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + mockModel, + null, + List.of(""), + false, + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + MatcherAssert.assertThat( + thrownException.getMessage(), + is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") + ); + + verify(factory, times(1)).createSender(); + verify(sender, times(1)).start(); + } + + verify(sender, times(1)).close(); + verifyNoMoreInteractions(factory); + verifyNoMoreInteractions(sender); + } + + public void testCheckModelConfig_UpdatesDimensions() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "model": "voyage-3-large", + "object": "list", + "usage": { + "total_tokens": 5 + }, + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.123, + -0.123 + ] + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = VoyageAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + 10, + 1, + "voyage-3-large" + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.checkModelConfig(model, listener); + var result = listener.actionGet(TIMEOUT); + + MatcherAssert.assertThat( + result, + // the dimension is set to 2 because there are 2 embeddings returned from the mock server + is( + VoyageAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + 10, + 2, + "voyage-3-large" + ) + ) + ); + } + } + + public void testCheckModelConfig_UpdatesSimilarityToDotProduct_WhenItIsNull() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "model": "voyage-3-large", + "object": "list", + "usage": { + "total_tokens": 5 + }, + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.123, + -0.123 + ] + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = VoyageAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + 10, + 1, + "voyage-3-large", + (SimilarityMeasure) null + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.checkModelConfig(model, listener); + var result = listener.actionGet(TIMEOUT); + + MatcherAssert.assertThat( + result, + // the dimension is set to 2 because there are 2 embeddings returned from the mock server + is( + VoyageAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + 10, + 2, + "voyage-3-large", + SimilarityMeasure.DOT_PRODUCT + ) + ) + ); + } + } + + public void testCheckModelConfig_DoesNotUpdateSimilarity_WhenItIsSpecifiedAsCosine() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "model": "voyage-3-large", + "object": "list", + "usage": { + "total_tokens": 5 + }, + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.123, + -0.123 + ] + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = VoyageAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + 10, + 1, + "voyage-3-large", + SimilarityMeasure.COSINE + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.checkModelConfig(model, listener); + var result = listener.actionGet(TIMEOUT); + + MatcherAssert.assertThat( + result, + // the dimension is set to 2 because there are 2 embeddings returned from the mock server + is( + VoyageAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + 10, + 2, + "voyage-3-large", + SimilarityMeasure.COSINE + ) + ) + ); + } + } + + public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel() throws IOException { + testUpdateModelWithEmbeddingDetails_Successful(null); + } + + public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel() throws IOException { + testUpdateModelWithEmbeddingDetails_Successful(randomFrom(SimilarityMeasure.values())); + } + + private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + var embeddingSize = randomNonNegativeInt(); + var model = VoyageAIEmbeddingsModelTests.createModel( + randomAlphaOfLength(10), + randomAlphaOfLength(10), + VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + randomNonNegativeInt(), + randomNonNegativeInt(), + randomAlphaOfLength(10), + similarityMeasure + ); + + Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize); + + SimilarityMeasure expectedSimilarityMeasure = similarityMeasure == null + ? VoyageAIService.defaultSimilarity() + : similarityMeasure; + assertEquals(expectedSimilarityMeasure, updatedModel.getServiceSettings().similarity()); + assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue()); + } + } + + public void testInfer_Embedding_UnauthorisedResponse() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "detail": "Unauthorized" + } + """; + webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson)); + + var model = VoyageAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + 1024, + 1024, + "model", + (SimilarityMeasure) null + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + List.of("abc"), + false, + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + MatcherAssert.assertThat(error.getMessage(), containsString("Received an authentication error status code for request")); + MatcherAssert.assertThat(error.getMessage(), containsString("Error message: [Unauthorized]")); + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + } + } + + public void testInfer_Rerank_UnauthorisedResponse() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "detail": "Unauthorized" + } + """; + webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson)); + + var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "model", 1024, false, false); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + "query", + List.of("candidate1", "candidate2"), + false, + new HashMap<>(), + null, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + MatcherAssert.assertThat(error.getMessage(), containsString("Received an authentication error status code for request")); + MatcherAssert.assertThat(error.getMessage(), containsString("Error message: [Unauthorized]")); + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + } + } + + public void testInfer_Embedding_Get_Response_Ingest() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "model": "voyage-3-large", + "object": "list", + "usage": { + "total_tokens": 5 + }, + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.123, + -0.123 + ] + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = VoyageAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + 1024, + 1024, + "voyage-3-large", + (SimilarityMeasure) null + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + List.of("abc"), + false, + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + assertEquals(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })), result.asMap()); + + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().getFirst().getUri().getQuery()); + MatcherAssert.assertThat( + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + MatcherAssert.assertThat( + requestMap, + is( + Map.of( + "input", + List.of("abc"), + "model", + "voyage-3-large", + "input_type", + "document", + "output_dtype", + "float", + "output_dimension", + 1024 + ) + ) + ); + } + } + + public void testInfer_Embedding_Get_Response_Search() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "model": "voyage-3-large", + "object": "list", + "usage": { + "total_tokens": 5 + }, + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.123, + -0.123 + ] + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = VoyageAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + 1024, + 1024, + "voyage-3-large", + (SimilarityMeasure) null + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + List.of("abc"), + false, + new HashMap<>(), + InputType.SEARCH, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + assertEquals(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })), result.asMap()); + + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().getFirst().getUri().getQuery()); + MatcherAssert.assertThat( + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + MatcherAssert.assertThat( + requestMap, + is( + Map.of( + "input", + List.of("abc"), + "model", + "voyage-3-large", + "input_type", + "query", + "output_dtype", + "float", + "output_dimension", + 1024 + ) + ) + ); + } + } + + public void testInfer_Embedding_Get_Response_clustering() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + {"model":"voyage-3-large","object":"list","usage":{"total_tokens":5}, + "data":[{"object":"embedding","index":0,"embedding":[0.123, -0.123]}]} + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = VoyageAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + 1024, + 1024, + "voyage-3-large", + (SimilarityMeasure) null + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + List.of("abc"), + false, + new HashMap<>(), + InputType.CLUSTERING, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + assertEquals(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })), result.asMap()); + + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().getFirst().getUri().getQuery()); + MatcherAssert.assertThat( + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + MatcherAssert.assertThat( + requestMap, + is(Map.of("input", List.of("abc"), "model", "voyage-3-large", "output_dtype", "float", "output_dimension", 1024)) + ); + } + } + + public void testInfer_Embedding_Get_Response_NullInputType() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "model": "voyage-3-large", + "object": "list", + "usage": { + "total_tokens": 5 + }, + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.123, + -0.123 + ] + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = VoyageAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + 1024, + 1024, + "voyage-3-large", + (SimilarityMeasure) null + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer(model, null, List.of("abc"), false, new HashMap<>(), null, InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + assertEquals(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })), result.asMap()); + + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().getFirst().getUri().getQuery()); + MatcherAssert.assertThat( + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + MatcherAssert.assertThat( + requestMap, + is(Map.of("input", List.of("abc"), "model", "voyage-3-large", "output_dtype", "float", "output_dimension", 1024)) + ); + } + } + + public void testInfer_Rerank_Get_Response_NoReturnDocuments_NoTopN() throws IOException { + String responseJson = """ + { + "model": "model", + "object": "list", + "data": [ + { + "index": 2, + "relevance_score": 0.98005307 + }, + { + "index": 1, + "relevance_score": 0.27904198 + }, + { + "index": 0, + "relevance_score": 0.10194652 + } + ], + "usage": { + "total_tokens": 15 + } + } + """; + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", null, false, false); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + "query", + List.of("candidate1", "candidate2", "candidate3"), + false, + new HashMap<>(), + null, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + var resultAsMap = result.asMap(); + assertThat( + resultAsMap, + is( + Map.of( + "rerank", + List.of( + Map.of("ranked_doc", Map.of("index", 2, "relevance_score", 0.98005307F)), + Map.of("ranked_doc", Map.of("index", 1, "relevance_score", 0.27904198F)), + Map.of("ranked_doc", Map.of("index", 0, "relevance_score", 0.10194652F)) + ) + ) + ) + ); + + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + MatcherAssert.assertThat( + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + MatcherAssert.assertThat( + requestMap, + is( + Map.of( + "query", + "query", + "documents", + List.of("candidate1", "candidate2", "candidate3"), + "model", + "model", + "return_documents", + false, + "truncation", + false + ) + ) + ); + + } + } + + public void testInfer_Rerank_Get_Response_NoReturnDocuments_TopN() throws IOException { + String responseJson = """ + { + "object": "list", + "model": "model", + "data": [ + { + "index": 2, + "relevance_score": 0.98005307 + }, + { + "index": 1, + "relevance_score": 0.27904198 + }, + { + "index": 0, + "relevance_score": 0.10194652 + } + ], + "usage": { + "total_tokens": 15 + } + } + """; + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", 3, false, false); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + "query", + List.of("candidate1", "candidate2", "candidate3", "candidate4"), + false, + new HashMap<>(), + null, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + var resultAsMap = result.asMap(); + assertThat( + resultAsMap, + is( + Map.of( + "rerank", + List.of( + Map.of("ranked_doc", Map.of("index", 2, "relevance_score", 0.98005307F)), + Map.of("ranked_doc", Map.of("index", 1, "relevance_score", 0.27904198F)), + Map.of("ranked_doc", Map.of("index", 0, "relevance_score", 0.10194652F)) + ) + ) + ) + ); + + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + MatcherAssert.assertThat( + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + MatcherAssert.assertThat( + requestMap, + is( + Map.of( + "query", + "query", + "documents", + List.of("candidate1", "candidate2", "candidate3", "candidate4"), + "model", + "model", + "return_documents", + false, + "top_k", + 3, + "truncation", + false + ) + ) + ); + + } + + } + + public void testInfer_Rerank_Get_Response_ReturnDocumentsNull_NoTopN() throws IOException { + String responseJson = """ + { + "object": "list", + "model": "model", + "data": [ + { + "index": 2, + "relevance_score": 0.98005307, + "document": "candidate3" + }, + { + "index": 1, + "relevance_score": 0.27904198, + "document": "candidate2" + }, + { + "index": 0, + "relevance_score": 0.10194652, + "document": "candidate1" + } + ], + "usage": { + "total_tokens": 15 + } + } + """; + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", null, null, null); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + "query", + List.of("candidate1", "candidate2", "candidate3"), + false, + new HashMap<>(), + null, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + var resultAsMap = result.asMap(); + assertThat( + resultAsMap, + is( + Map.of( + "rerank", + List.of( + Map.of("ranked_doc", Map.of("text", "candidate3", "index", 2, "relevance_score", 0.98005307F)), + Map.of("ranked_doc", Map.of("text", "candidate2", "index", 1, "relevance_score", 0.27904198F)), + Map.of("ranked_doc", Map.of("text", "candidate1", "index", 0, "relevance_score", 0.10194652F)) + ) + ) + ) + ); + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + MatcherAssert.assertThat( + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + MatcherAssert.assertThat( + requestMap, + is(Map.of("query", "query", "documents", List.of("candidate1", "candidate2", "candidate3"), "model", "model")) + ); + + } + + } + + public void testInfer_Rerank_Get_Response_ReturnDocuments_TopN() throws IOException { + String responseJson = """ + { + "object": "list", + "model": "model", + "data": [ + { + "index": 2, + "relevance_score": 0.98005307, + "document": "candidate3" + }, + { + "index": 1, + "relevance_score": 0.27904198, + "document": "candidate2" + }, + { + "index": 0, + "relevance_score": 0.10194652, + "document": "candidate1" + } + ], + "usage": { + "total_tokens": 15 + } + } + """; + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", 3, true, true); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + "query", + List.of("candidate1", "candidate2", "candidate3", "candidate4"), + false, + new HashMap<>(), + null, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + var resultAsMap = result.asMap(); + assertThat( + resultAsMap, + is( + Map.of( + "rerank", + List.of( + Map.of("ranked_doc", Map.of("text", "candidate3", "index", 2, "relevance_score", 0.98005307F)), + Map.of("ranked_doc", Map.of("text", "candidate2", "index", 1, "relevance_score", 0.27904198F)), + Map.of("ranked_doc", Map.of("text", "candidate1", "index", 0, "relevance_score", 0.10194652F)) + ) + ) + ) + ); + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + MatcherAssert.assertThat( + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + MatcherAssert.assertThat( + requestMap, + is( + Map.of( + "query", + "query", + "documents", + List.of("candidate1", "candidate2", "candidate3", "candidate4"), + "model", + "model", + "return_documents", + true, + "top_k", + 3, + "truncation", + true + ) + ) + ); + + } + + } + + public void testInfer_Embedding_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspecifiedIsPassedInRequest() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "model": "voyage-3-large", + "object": "list", + "usage": { + "total_tokens": 5 + }, + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.123, + -0.123 + ] + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = VoyageAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), + 1024, + 1024, + "voyage-3-large", + (SimilarityMeasure) null + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + List.of("abc"), + false, + new HashMap<>(), + InputType.UNSPECIFIED, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + MatcherAssert.assertThat(result.asMap(), Matchers.is(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })))); + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().getFirst().getUri().getQuery()); + MatcherAssert.assertThat( + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + MatcherAssert.assertThat( + requestMap, + is( + Map.of( + "input", + List.of("abc"), + "model", + "voyage-3-large", + "input_type", + "document", + "output_dtype", + "float", + "output_dimension", + 1024 + ) + ) + ); + } + } + + public void test_Embedding_ChunkedInfer_BatchesCallsChunkingSettingsSet() throws IOException { + var model = VoyageAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + new VoyageAIEmbeddingsTaskSettings((InputType) null, null), + createRandomChunkingSettings(), + 1024, + 1024, + "voyage-3-large" + ); + + test_Embedding_ChunkedInfer_BatchesCalls(model); + } + + public void test_Embedding_ChunkedInfer_ChunkingSettingsNotSet() throws IOException { + var model = VoyageAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + new VoyageAIEmbeddingsTaskSettings((InputType) null, null), + null, + 1024, + 1024, + "voyage-3-large" + ); + + test_Embedding_ChunkedInfer_BatchesCalls(model); + } + + private void test_Embedding_ChunkedInfer_BatchesCalls(VoyageAIEmbeddingsModel model) throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + + // Batching will call the service with 2 input + String responseJson = """ + { + "model": "voyage-3-large", + "object": "list", + "usage": { + "total_tokens": 5 + }, + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.123, + -0.123 + ] + }, + { + "object": "embedding", + "index": 1, + "embedding": [ + 0.223, + -0.223 + ] + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + PlainActionFuture> listener = new PlainActionFuture<>(); + // 2 input + service.chunkedInfer( + model, + null, + List.of("foo", "bar"), + new HashMap<>(), + InputType.UNSPECIFIED, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var results = listener.actionGet(TIMEOUT); + assertThat(results, hasSize(2)); + { + assertThat(results.getFirst(), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var floatResult = (ChunkedInferenceEmbeddingFloat) results.getFirst(); + assertThat(floatResult.chunks(), hasSize(1)); + assertEquals("foo", floatResult.chunks().getFirst().matchedText()); + assertArrayEquals(new float[] { 0.123f, -0.123f }, floatResult.chunks().getFirst().embedding(), 0.0f); + } + { + assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(1); + assertThat(floatResult.chunks(), hasSize(1)); + assertEquals("bar", floatResult.chunks().getFirst().matchedText()); + assertArrayEquals(new float[] { 0.223f, -0.223f }, floatResult.chunks().getFirst().embedding(), 0.0f); + } + + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().getFirst().getUri().getQuery()); + MatcherAssert.assertThat( + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + MatcherAssert.assertThat( + requestMap, + is(Map.of("input", List.of("foo", "bar"), "model", "voyage-3-large", "output_dtype", "float", "output_dimension", 1024)) + ); + } + } + + public void testDefaultSimilarity() { + assertEquals(SimilarityMeasure.DOT_PRODUCT, VoyageAIService.defaultSimilarity()); + } + + @SuppressWarnings("checkstyle:LineLength") + public void testGetConfiguration() throws Exception { + try (var service = createVoyageAIService()) { + String content = XContentHelper.stripWhitespace(""" + { + "service": "voyageai", + "name": "Voyage AI", + "task_types": ["text_embedding", "rerank"], + "configurations": { + "model_id": { + "description": "The name of the model to use for the inference task.", + "label": "Model ID", + "required": true, + "sensitive": false, + "updatable": false, + "type": "str", + "supported_task_types": ["text_embedding", "rerank"] + }, + "api_key": { + "description": "API Key for the provider you're connecting to.", + "label": "API Key", + "required": true, + "sensitive": true, + "updatable": true, + "type": "str", + "supported_task_types": ["text_embedding", "rerank"] + }, + "rate_limit.requests_per_minute": { + "description": "Minimize the number of rate limit errors.", + "label": "Rate Limit", + "required": false, + "sensitive": false, + "updatable": false, + "type": "int", + "supported_task_types": ["text_embedding", "rerank"] + } + } + } + """); + InferenceServiceConfiguration configuration = InferenceServiceConfiguration.fromXContentBytes( + new BytesArray(content), + XContentType.JSON + ); + boolean humanReadable = true; + BytesReference originalBytes = toShuffledXContent(configuration, XContentType.JSON, ToXContent.EMPTY_PARAMS, humanReadable); + InferenceServiceConfiguration serviceConfiguration = service.getConfiguration(); + assertToXContentEquivalent( + originalBytes, + toXContent(serviceConfiguration, XContentType.JSON, humanReadable), + XContentType.JSON + ); + } + } + + public void testDoesNotSupportsStreaming() throws IOException { + try (var service = new VoyageAIService(mock(), createWithEmptySettings(mock()))) { + assertFalse(service.canStream(TaskType.COMPLETION)); + assertFalse(service.canStream(TaskType.ANY)); + } + } + + private Map getRequestConfigMap( + Map serviceSettings, + Map taskSettings, + Map chunkingSettings, + Map secretSettings + ) { + var requestConfigMap = getRequestConfigMap(serviceSettings, taskSettings, secretSettings); + requestConfigMap.put(ModelConfigurations.CHUNKING_SETTINGS, chunkingSettings); + + return requestConfigMap; + } + + private Map getRequestConfigMap( + Map serviceSettings, + Map taskSettings, + Map secretSettings + ) { + var builtServiceSettings = new HashMap<>(); + builtServiceSettings.putAll(serviceSettings); + builtServiceSettings.putAll(secretSettings); + + return new HashMap<>( + Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings) + ); + } + + private Map getRequestConfigMap(Map serviceSettings, Map secretSettings) { + var builtServiceSettings = new HashMap<>(); + builtServiceSettings.putAll(serviceSettings); + builtServiceSettings.putAll(secretSettings); + + return new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings)); + } + + private VoyageAIService createVoyageAIService() { + return new VoyageAIService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsModelTests.java new file mode 100644 index 0000000000000..e497e606b0689 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsModelTests.java @@ -0,0 +1,209 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai.embeddings; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings; +import org.hamcrest.MatcherAssert; + +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap; +import static org.hamcrest.Matchers.is; + +public class VoyageAIEmbeddingsModelTests extends ESTestCase { + + public void testOverrideWith_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreEmpty_AndInputTypeIsInvalid() { + var model = createModel("url", "api_key", null, null, "model"); + + var overriddenModel = VoyageAIEmbeddingsModel.of(model, Map.of(), InputType.UNSPECIFIED); + MatcherAssert.assertThat(overriddenModel, is(model)); + } + + public void testOverrideWith_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreNull_AndInputTypeIsInvalid() { + var model = createModel("url", "api_key", null, null, "model"); + + var overriddenModel = VoyageAIEmbeddingsModel.of(model, null, InputType.UNSPECIFIED); + MatcherAssert.assertThat(overriddenModel, is(model)); + } + + public void testOverrideWith_SetsInputTypeToIngest_WhenTheFieldIsNullInModelTaskSettings_AndNullInRequestTaskSettings() { + var model = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings((InputType) null, null), null, null, "model"); + + var overriddenModel = VoyageAIEmbeddingsModel.of(model, getTaskSettingsMap(null), InputType.INGEST); + var expectedModel = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), null, null, "model"); + MatcherAssert.assertThat(overriddenModel, is(expectedModel)); + } + + public void testOverrideWith_SetsInputType_FromRequest_IfValid_OverridingStoredTaskSettings() { + var model = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), null, null, "model"); + + var overriddenModel = VoyageAIEmbeddingsModel.of(model, getTaskSettingsMap(null), InputType.SEARCH); + var expectedModel = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH, null), null, null, "model"); + MatcherAssert.assertThat(overriddenModel, is(expectedModel)); + } + + public void testOverrideWith_SetsInputType_FromRequest_IfValid_OverridingRequestTaskSettings() { + var model = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings((InputType) null, null), null, null, "model"); + + var overriddenModel = VoyageAIEmbeddingsModel.of(model, getTaskSettingsMap(InputType.INGEST), InputType.SEARCH); + var expectedModel = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH, null), null, null, "model"); + MatcherAssert.assertThat(overriddenModel, is(expectedModel)); + } + + public void testOverrideWith_OverridesInputType_WithRequestTaskSettingsSearch_WhenRequestInputTypeIsInvalid() { + var model = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), null, null, "model"); + + var overriddenModel = VoyageAIEmbeddingsModel.of(model, getTaskSettingsMap(InputType.SEARCH), InputType.UNSPECIFIED); + var expectedModel = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH, null), null, null, "model"); + MatcherAssert.assertThat(overriddenModel, is(expectedModel)); + } + + public void testOverrideWith_DoesNotSetInputType_FromRequest_IfInputTypeIsInvalid() { + var model = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings((InputType) null, null), null, null, "model"); + + var overriddenModel = VoyageAIEmbeddingsModel.of(model, getTaskSettingsMap(null), InputType.UNSPECIFIED); + var expectedModel = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings((InputType) null, null), null, null, "model"); + MatcherAssert.assertThat(overriddenModel, is(expectedModel)); + } + + public void testOverrideWith_DoesNotSetInputType_WhenRequestTaskSettingsIsNull_AndRequestInputTypeIsInvalid() { + var model = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), null, null, "model"); + + var overriddenModel = VoyageAIEmbeddingsModel.of(model, getTaskSettingsMap(null), InputType.UNSPECIFIED); + var expectedModel = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), null, null, "model"); + MatcherAssert.assertThat(overriddenModel, is(expectedModel)); + } + + public static VoyageAIEmbeddingsModel createModel(String url, String apiKey, @Nullable Integer tokenLimit, @Nullable String model) { + return createModel(url, apiKey, VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, tokenLimit, null, model); + } + + public static VoyageAIEmbeddingsModel createModel( + String url, + String apiKey, + @Nullable Integer tokenLimit, + @Nullable Integer dimensions, + String model + ) { + return createModel(url, apiKey, VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, tokenLimit, dimensions, model); + } + + public static VoyageAIEmbeddingsModel createModel( + String url, + String apiKey, + VoyageAIEmbeddingsTaskSettings taskSettings, + ChunkingSettings chunkingSettings, + @Nullable Integer tokenLimit, + @Nullable Integer dimensions, + String model + ) { + return new VoyageAIEmbeddingsModel( + "id", + "service", + url, + new VoyageAIEmbeddingsServiceSettings( + new VoyageAIServiceSettings(model, null), + VoyageAIEmbeddingType.FLOAT, + SimilarityMeasure.DOT_PRODUCT, + dimensions, + tokenLimit, + false + ), + taskSettings, + chunkingSettings, + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static VoyageAIEmbeddingsModel createModel( + String url, + String apiKey, + VoyageAIEmbeddingsTaskSettings taskSettings, + @Nullable Integer tokenLimit, + @Nullable Integer dimensions, + String model + ) { + return new VoyageAIEmbeddingsModel( + "id", + "service", + url, + new VoyageAIEmbeddingsServiceSettings( + new VoyageAIServiceSettings(model, null), + VoyageAIEmbeddingType.FLOAT, + SimilarityMeasure.DOT_PRODUCT, + dimensions, + tokenLimit, + false + ), + taskSettings, + null, + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static VoyageAIEmbeddingsModel createModel( + String url, + String apiKey, + VoyageAIEmbeddingsTaskSettings taskSettings, + @Nullable Integer tokenLimit, + @Nullable Integer dimensions, + String model, + VoyageAIEmbeddingType embeddingType + ) { + return new VoyageAIEmbeddingsModel( + "id", + "service", + url, + new VoyageAIEmbeddingsServiceSettings( + new VoyageAIServiceSettings(model, null), + embeddingType, + SimilarityMeasure.DOT_PRODUCT, + dimensions, + tokenLimit, + false + ), + taskSettings, + null, + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static VoyageAIEmbeddingsModel createModel( + String url, + String apiKey, + VoyageAIEmbeddingsTaskSettings taskSettings, + @Nullable Integer tokenLimit, + @Nullable Integer dimensions, + String model, + @Nullable SimilarityMeasure similarityMeasure + ) { + return new VoyageAIEmbeddingsModel( + "id", + "service", + url, + new VoyageAIEmbeddingsServiceSettings( + new VoyageAIServiceSettings(model, null), + VoyageAIEmbeddingType.FLOAT, + similarityMeasure, + dimensions, + tokenLimit, + false + ), + taskSettings, + null, + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettingsTests.java new file mode 100644 index 0000000000000..0bedd0d60c25c --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettingsTests.java @@ -0,0 +1,327 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai.embeddings; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettingsTests; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER; +import static org.hamcrest.Matchers.is; + +public class VoyageAIEmbeddingsServiceSettingsTests extends AbstractWireSerializingTestCase { + public static VoyageAIEmbeddingsServiceSettings createRandom() { + SimilarityMeasure similarityMeasure = SimilarityMeasure.DOT_PRODUCT; + Integer dims = 1024; + Integer maxInputTokens = randomBoolean() ? null : randomIntBetween(128, 256); + Boolean dimensionSetByUser = randomBoolean(); + + var commonSettings = VoyageAIServiceSettingsTests.createRandom(); + + return new VoyageAIEmbeddingsServiceSettings( + commonSettings, + VoyageAIEmbeddingType.FLOAT, + similarityMeasure, + dims, + maxInputTokens, + dimensionSetByUser + ); + } + + public void testFromMap() { + var similarity = SimilarityMeasure.DOT_PRODUCT.toString(); + var dims = 1536; + var maxInputTokens = 512; + var model = "model"; + var serviceSettings = VoyageAIEmbeddingsServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.SIMILARITY, + similarity, + ServiceFields.DIMENSIONS, + dims, + ServiceFields.MAX_INPUT_TOKENS, + maxInputTokens, + VoyageAIServiceSettings.MODEL_ID, + model + ) + ), + ConfigurationParseContext.PERSISTENT + ); + + MatcherAssert.assertThat( + serviceSettings, + is( + new VoyageAIEmbeddingsServiceSettings( + new VoyageAIServiceSettings(model, null), + VoyageAIEmbeddingType.FLOAT, + SimilarityMeasure.DOT_PRODUCT, + dims, + maxInputTokens, + false + ) + ) + ); + } + + public void testFromMap_WithModelId() { + var similarity = SimilarityMeasure.DOT_PRODUCT.toString(); + var maxInputTokens = 512; + var model = "model"; + var serviceSettings = VoyageAIEmbeddingsServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.SIMILARITY, + similarity, + ServiceFields.MAX_INPUT_TOKENS, + maxInputTokens, + VoyageAIServiceSettings.MODEL_ID, + model + ) + ), + ConfigurationParseContext.REQUEST + ); + + MatcherAssert.assertThat( + serviceSettings, + is( + new VoyageAIEmbeddingsServiceSettings( + new VoyageAIServiceSettings(model, null), + VoyageAIEmbeddingType.FLOAT, + SimilarityMeasure.DOT_PRODUCT, + null, + maxInputTokens, + false + ) + ) + ); + } + + public void testFromMap_WithModelId_WithDimensions() { + var similarity = SimilarityMeasure.DOT_PRODUCT.toString(); + var dims = 1536; + var maxInputTokens = 512; + var model = "model"; + var serviceSettings = VoyageAIEmbeddingsServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.SIMILARITY, + similarity, + ServiceFields.DIMENSIONS, + dims, + ServiceFields.MAX_INPUT_TOKENS, + maxInputTokens, + VoyageAIServiceSettings.MODEL_ID, + model + ) + ), + ConfigurationParseContext.REQUEST + ); + + MatcherAssert.assertThat( + serviceSettings, + is( + new VoyageAIEmbeddingsServiceSettings( + new VoyageAIServiceSettings(model, null), + VoyageAIEmbeddingType.FLOAT, + SimilarityMeasure.DOT_PRODUCT, + dims, + maxInputTokens, + true + ) + ) + ); + } + + public void testFromMap_DimensionsSetByUserIsFalseInRequestContext() { + var similarity = SimilarityMeasure.DOT_PRODUCT.toString(); + var maxInputTokens = 512; + var model = "model"; + var serviceSettings = VoyageAIEmbeddingsServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.SIMILARITY, + similarity, + DIMENSIONS_SET_BY_USER, + true, + ServiceFields.MAX_INPUT_TOKENS, + maxInputTokens, + VoyageAIServiceSettings.MODEL_ID, + model + ) + ), + ConfigurationParseContext.REQUEST + ); + + MatcherAssert.assertThat( + serviceSettings, + is( + new VoyageAIEmbeddingsServiceSettings( + new VoyageAIServiceSettings(model, null), + VoyageAIEmbeddingType.FLOAT, + SimilarityMeasure.DOT_PRODUCT, + null, + maxInputTokens, + false + ) + ) + ); + } + + public void testFromMap_DimensionsSetByUserIsSetInPersistentContext() { + var similarity = SimilarityMeasure.DOT_PRODUCT.toString(); + var maxInputTokens = 512; + var model = "model"; + var dimensionsSetByUser = randomBoolean(); + var serviceSettings = VoyageAIEmbeddingsServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.SIMILARITY, + similarity, + DIMENSIONS_SET_BY_USER, + dimensionsSetByUser, + ServiceFields.MAX_INPUT_TOKENS, + maxInputTokens, + VoyageAIServiceSettings.MODEL_ID, + model + ) + ), + ConfigurationParseContext.PERSISTENT + ); + + MatcherAssert.assertThat( + serviceSettings, + is( + new VoyageAIEmbeddingsServiceSettings( + new VoyageAIServiceSettings(model, null), + VoyageAIEmbeddingType.FLOAT, + SimilarityMeasure.DOT_PRODUCT, + null, + maxInputTokens, + dimensionsSetByUser + ) + ) + ); + } + + public void testFromMap_InvalidSimilarity_ThrowsError() { + var similarity = "by_size"; + var thrownException = expectThrows( + ValidationException.class, + () -> VoyageAIEmbeddingsServiceSettings.fromMap( + new HashMap<>(Map.of(VoyageAIServiceSettings.MODEL_ID, "model", ServiceFields.SIMILARITY, similarity)), + ConfigurationParseContext.PERSISTENT + ) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + is( + "Validation Failed: 1: [service_settings] Invalid value [by_size] received. [similarity] " + + "must be one of [cosine, dot_product, l2_norm];" + ) + ); + } + + @SuppressWarnings("checkstyle:LineLength") + public void testToXContent_WritesAllValues() throws IOException { + var serviceSettings = new VoyageAIEmbeddingsServiceSettings( + new VoyageAIServiceSettings("model", new RateLimitSettings(3)), + VoyageAIEmbeddingType.FLOAT, + SimilarityMeasure.COSINE, + 5, + 10, + false + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + assertThat( + xContentResult, + is( + """ + {"model_id":"model",""" + + """ + "rate_limit":{"requests_per_minute":3},"similarity":"cosine","dimensions":5,"max_input_tokens":10,"embedding_type":"float"}""" + ) + ); + } + + @SuppressWarnings("checkstyle:LineLength") + public void testToXContent_WritesAllValues_DimensionSetByUser() throws IOException { + var serviceSettings = new VoyageAIEmbeddingsServiceSettings( + new VoyageAIServiceSettings("model", new RateLimitSettings(3)), + VoyageAIEmbeddingType.FLOAT, + SimilarityMeasure.COSINE, + 5, + 10, + true + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + assertThat( + xContentResult, + is( + """ + {"model_id":"model",""" + + """ + "rate_limit":{"requests_per_minute":3},"similarity":"cosine","dimensions":5,"max_input_tokens":10,"embedding_type":"float"}""" + ) + ); + } + + @Override + protected Writeable.Reader instanceReader() { + return VoyageAIEmbeddingsServiceSettings::new; + } + + @Override + protected VoyageAIEmbeddingsServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected VoyageAIEmbeddingsServiceSettings mutateInstance(VoyageAIEmbeddingsServiceSettings instance) throws IOException { + return randomValueOtherThan(instance, VoyageAIEmbeddingsServiceSettingsTests::createRandom); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + List entries = new ArrayList<>(); + entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); + entries.addAll(InferenceNamedWriteablesProvider.getNamedWriteables()); + return new NamedWriteableRegistry(entries); + } + + public static Map getServiceSettingsMap(String model) { + return new HashMap<>(VoyageAIServiceSettingsTests.getServiceSettingsMap(model)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsTaskSettingsTests.java new file mode 100644 index 0000000000000..f3d85749e8e29 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsTaskSettingsTests.java @@ -0,0 +1,218 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai.embeddings; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceFields; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.Locale; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.InputTypeTests.randomWithIngestAndSearch; +import static org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings.VALID_REQUEST_VALUES; +import static org.hamcrest.Matchers.is; + +public class VoyageAIEmbeddingsTaskSettingsTests extends AbstractWireSerializingTestCase { + + public static VoyageAIEmbeddingsTaskSettings createRandom() { + var inputType = randomBoolean() ? randomWithIngestAndSearch() : null; + var truncation = randomBoolean(); + + return new VoyageAIEmbeddingsTaskSettings(inputType, truncation); + } + + public void testIsEmpty() { + var randomSettings = createRandom(); + var stringRep = Strings.toString(randomSettings); + assertEquals(stringRep, randomSettings.isEmpty(), stringRep.equals("{}")); + } + + public void testUpdatedTaskSettings_NotUpdated_UseInitialSettings() { + var initialSettings = createRandom(); + var newSettings = new VoyageAIEmbeddingsTaskSettings((InputType) null, null); + Map newSettingsMap = new HashMap<>(); + VoyageAIEmbeddingsTaskSettings updatedSettings = (VoyageAIEmbeddingsTaskSettings) initialSettings.updatedTaskSettings( + Collections.unmodifiableMap(newSettingsMap) + ); + assertEquals(initialSettings.getInputType(), updatedSettings.getInputType()); + } + + public void testUpdatedTaskSettings_Updated_UseNewSettings() { + var initialSettings = createRandom(); + var newSettings = new VoyageAIEmbeddingsTaskSettings(randomWithIngestAndSearch(), randomBoolean()); + Map newSettingsMap = new HashMap<>(); + newSettingsMap.put(VoyageAIEmbeddingsTaskSettings.INPUT_TYPE, newSettings.getInputType().toString()); + VoyageAIEmbeddingsTaskSettings updatedSettings = (VoyageAIEmbeddingsTaskSettings) initialSettings.updatedTaskSettings( + Collections.unmodifiableMap(newSettingsMap) + ); + assertEquals(newSettings.getInputType(), updatedSettings.getInputType()); + } + + public void testFromMap_CreatesEmptySettings_WhenAllFieldsAreNull() { + MatcherAssert.assertThat( + VoyageAIEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of())), + is(new VoyageAIEmbeddingsTaskSettings((InputType) null, null)) + ); + } + + public void testFromMap_CreatesEmptySettings_WhenMapIsNull() { + MatcherAssert.assertThat( + VoyageAIEmbeddingsTaskSettings.fromMap(null), + is(new VoyageAIEmbeddingsTaskSettings((InputType) null, null)) + ); + } + + public void testFromMap_CreatesSettings_WhenAllFieldsOfSettingsArePresent() { + MatcherAssert.assertThat( + VoyageAIEmbeddingsTaskSettings.fromMap( + new HashMap<>( + Map.of(VoyageAIEmbeddingsTaskSettings.INPUT_TYPE, InputType.INGEST.toString(), VoyageAIServiceFields.TRUNCATION, false) + ) + ), + is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, false)) + ); + } + + public void testFromMap_ReturnsFailure_WhenInputTypeIsInvalid() { + var exception = expectThrows( + ValidationException.class, + () -> VoyageAIEmbeddingsTaskSettings.fromMap( + new HashMap<>(Map.of(VoyageAIEmbeddingsTaskSettings.INPUT_TYPE, "abc", VoyageAIServiceFields.TRUNCATION, false)) + ) + ); + + MatcherAssert.assertThat( + exception.getMessage(), + is( + Strings.format( + "Validation Failed: 1: [task_settings] Invalid value [abc] received. [input_type] must be one of [%s];", + getValidValuesSortedAndCombined(VALID_REQUEST_VALUES) + ) + ) + ); + } + + public void testFromMap_ReturnsFailure_WhenTruncationIsInvalid() { + var exception = expectThrows( + ValidationException.class, + () -> VoyageAIEmbeddingsTaskSettings.fromMap( + new HashMap<>( + Map.of(VoyageAIEmbeddingsTaskSettings.INPUT_TYPE, InputType.INGEST.toString(), VoyageAIServiceFields.TRUNCATION, "abc") + ) + ) + ); + + MatcherAssert.assertThat( + exception.getMessage(), + is("Validation Failed: 1: field [truncation] is not of the expected type. The value [abc] cannot be converted to a [Boolean];") + ); + } + + public void testFromMap_ReturnsFailure_WhenInputTypeIsUnspecified() { + var exception = expectThrows( + ValidationException.class, + () -> VoyageAIEmbeddingsTaskSettings.fromMap( + new HashMap<>(Map.of(VoyageAIEmbeddingsTaskSettings.INPUT_TYPE, InputType.UNSPECIFIED.toString())) + ) + ); + + MatcherAssert.assertThat( + exception.getMessage(), + is( + Strings.format( + "Validation Failed: 1: [task_settings] Invalid value [unspecified] received. [input_type] must be one of [%s];", + getValidValuesSortedAndCombined(VALID_REQUEST_VALUES) + ) + ) + ); + } + + private static > String getValidValuesSortedAndCombined(EnumSet validValues) { + var validValuesAsStrings = validValues.stream().map(value -> value.toString().toLowerCase(Locale.ROOT)).toArray(String[]::new); + Arrays.sort(validValuesAsStrings); + + return String.join(", ", validValuesAsStrings); + } + + public void testXContent_ThrowsAssertionFailure_WhenInputTypeIsUnspecified() { + var thrownException = expectThrows(AssertionError.class, () -> new VoyageAIEmbeddingsTaskSettings(InputType.UNSPECIFIED, null)); + MatcherAssert.assertThat(thrownException.getMessage(), is("received invalid input type value [unspecified]")); + } + + public void testOf_KeepsOriginalValuesWhenRequestSettingsAreNull_AndRequestInputTypeIsInvalid() { + var taskSettings = new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, false); + var overriddenTaskSettings = VoyageAIEmbeddingsTaskSettings.of( + taskSettings, + VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + InputType.UNSPECIFIED + ); + MatcherAssert.assertThat(overriddenTaskSettings, is(taskSettings)); + } + + public void testOf_UsesRequestTaskSettings() { + var taskSettings = new VoyageAIEmbeddingsTaskSettings((InputType) null, null); + var overriddenTaskSettings = VoyageAIEmbeddingsTaskSettings.of( + taskSettings, + new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, true), + InputType.UNSPECIFIED + ); + + MatcherAssert.assertThat(overriddenTaskSettings, is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, true))); + } + + public void testOf_UsesRequestTaskSettings_AndRequestInputType() { + var taskSettings = new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH, true); + var overriddenTaskSettings = VoyageAIEmbeddingsTaskSettings.of( + taskSettings, + new VoyageAIEmbeddingsTaskSettings((InputType) null, null), + InputType.INGEST + ); + + MatcherAssert.assertThat(overriddenTaskSettings, is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, true))); + } + + @Override + protected Writeable.Reader instanceReader() { + return VoyageAIEmbeddingsTaskSettings::new; + } + + @Override + protected VoyageAIEmbeddingsTaskSettings createTestInstance() { + return createRandom(); + } + + @Override + protected VoyageAIEmbeddingsTaskSettings mutateInstance(VoyageAIEmbeddingsTaskSettings instance) throws IOException { + return randomValueOtherThan(instance, VoyageAIEmbeddingsTaskSettingsTests::createRandom); + } + + public static Map getTaskSettingsMapEmpty() { + return new HashMap<>(); + } + + public static Map getTaskSettingsMap(@Nullable InputType inputType) { + var map = new HashMap(); + + if (inputType != null) { + map.put(VoyageAIEmbeddingsTaskSettings.INPUT_TYPE, inputType.toString()); + } + + return map; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModelTests.java new file mode 100644 index 0000000000000..d9ceca107f9e4 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModelTests.java @@ -0,0 +1,96 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai.rerank; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings; + +public class VoyageAIRerankModelTests { + public static VoyageAIRerankModel createModel(String apiKey, String modelId, @Nullable Integer topK, @Nullable Boolean truncation) { + return new VoyageAIRerankModel( + "id", + "service", + ESTestCase.randomAlphaOfLength(10), + new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(modelId, null)), + new VoyageAIRerankTaskSettings(topK, null, truncation), + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static VoyageAIRerankModel createModel(String apiKey, String modelId, @Nullable Integer topK) { + return new VoyageAIRerankModel( + "id", + "service", + ESTestCase.randomAlphaOfLength(10), + new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(modelId, null)), + new VoyageAIRerankTaskSettings(topK, null, null), + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static VoyageAIRerankModel createModel(String modelId, @Nullable Integer topK) { + return new VoyageAIRerankModel( + "id", + "service", + ESTestCase.randomAlphaOfLength(10), + new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(modelId, null)), + new VoyageAIRerankTaskSettings(topK, null, null), + new DefaultSecretSettings(ESTestCase.randomSecureStringOfLength(8)) + ); + } + + public static VoyageAIRerankModel createModel(String modelId, @Nullable Integer topK, Boolean returnDocuments, Boolean truncation) { + return new VoyageAIRerankModel( + "id", + "service", + ESTestCase.randomAlphaOfLength(10), + new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(modelId, null)), + new VoyageAIRerankTaskSettings(topK, returnDocuments, truncation), + new DefaultSecretSettings(ESTestCase.randomSecureStringOfLength(8)) + ); + } + + public static VoyageAIRerankModel createModel( + String url, + String modelId, + @Nullable Integer topK, + Boolean returnDocuments, + Boolean truncation + ) { + return new VoyageAIRerankModel( + "id", + "service", + url, + new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(modelId, null)), + new VoyageAIRerankTaskSettings(topK, returnDocuments, truncation), + new DefaultSecretSettings(ESTestCase.randomSecureStringOfLength(8)) + ); + } + + public static VoyageAIRerankModel createModel( + String url, + String apiKey, + String modelId, + @Nullable Integer topK, + Boolean returnDocuments, + Boolean truncation + ) { + return new VoyageAIRerankModel( + "id", + "service", + url, + new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(modelId, null)), + new VoyageAIRerankTaskSettings(topK, returnDocuments, truncation), + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankServiceSettingsTests.java new file mode 100644 index 0000000000000..7891d5cc7cca0 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankServiceSettingsTests.java @@ -0,0 +1,78 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai.rerank; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettingsTests; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString; + +public class VoyageAIRerankServiceSettingsTests extends AbstractBWCWireSerializationTestCase { + public static VoyageAIRerankServiceSettings createRandom() { + return new VoyageAIRerankServiceSettings( + new VoyageAIServiceSettings(randomAlphaOfLength(10), RateLimitSettingsTests.createRandom()) + ); + } + + public void testToXContent_WritesAllValues() throws IOException { + var url = "http://www.abc.com"; + var model = "model"; + + var serviceSettings = new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(model, null)); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "model_id":"model", + "rate_limit": { + "requests_per_minute": 2000 + } + } + """)); + } + + @Override + protected Writeable.Reader instanceReader() { + return VoyageAIRerankServiceSettings::new; + } + + @Override + protected VoyageAIRerankServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected VoyageAIRerankServiceSettings mutateInstance(VoyageAIRerankServiceSettings instance) throws IOException { + return randomValueOtherThan(instance, VoyageAIRerankServiceSettingsTests::createRandom); + } + + @Override + protected VoyageAIRerankServiceSettings mutateInstanceForVersion(VoyageAIRerankServiceSettings instance, TransportVersion version) { + return instance; + } + + public static Map getServiceSettingsMap(@Nullable String model) { + return new HashMap<>(VoyageAIServiceSettingsTests.getServiceSettingsMap(model)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettingsTests.java new file mode 100644 index 0000000000000..02c8f9ae677ef --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettingsTests.java @@ -0,0 +1,162 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai.rerank; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceFields; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.containsString; + +public class VoyageAIRerankTaskSettingsTests extends AbstractWireSerializingTestCase { + + public static VoyageAIRerankTaskSettings createRandom() { + var returnDocuments = randomBoolean() ? randomBoolean() : null; + var topNDocsOnly = randomBoolean() ? randomIntBetween(1, 10) : null; + var truncation = randomBoolean() ? randomBoolean() : null; + + return new VoyageAIRerankTaskSettings(topNDocsOnly, returnDocuments, truncation); + } + + public void testFromMap_WithInvalidTruncation_ThrowsValidationException() { + Map taskMap = Map.of( + VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, + true, + VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY, + 5, + VoyageAIServiceFields.TRUNCATION, + "invalid" + ); + var thrownException = expectThrows(ValidationException.class, () -> VoyageAIRerankTaskSettings.fromMap(new HashMap<>(taskMap))); + assertThat(thrownException.getMessage(), containsString("field [truncation] is not of the expected type")); + } + + public void testFromMap_WithValidValues_ReturnsSettings() { + Map taskMap = Map.of( + VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, + true, + VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY, + 5, + VoyageAIServiceFields.TRUNCATION, + true + ); + var settings = VoyageAIRerankTaskSettings.fromMap(new HashMap<>(taskMap)); + assertTrue(settings.getReturnDocuments()); + assertEquals(5, settings.getTopKDocumentsOnly().intValue()); + assertTrue(settings.getTruncation()); + } + + public void testFromMap_WithNullValues_ReturnsSettingsWithNulls() { + var settings = VoyageAIRerankTaskSettings.fromMap(Map.of()); + assertNull(settings.getReturnDocuments()); + assertNull(settings.getTopKDocumentsOnly()); + assertNull(settings.getTruncation()); + } + + public void testFromMap_WithInvalidReturnDocuments_ThrowsValidationException() { + Map taskMap = Map.of( + VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, + "invalid", + VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY, + 5 + ); + var thrownException = expectThrows(ValidationException.class, () -> VoyageAIRerankTaskSettings.fromMap(new HashMap<>(taskMap))); + assertThat(thrownException.getMessage(), containsString("field [return_documents] is not of the expected type")); + } + + public void testFromMap_WithInvalidTopNDocsOnly_ThrowsValidationException() { + Map taskMap = Map.of( + VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, + true, + VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY, + "invalid" + ); + var thrownException = expectThrows(ValidationException.class, () -> VoyageAIRerankTaskSettings.fromMap(new HashMap<>(taskMap))); + assertThat( + thrownException.getMessage(), + containsString("field [top_k] is not of the expected type. The value [invalid] cannot be converted to a [Integer];") + ); + } + + public void testUpdatedTaskSettings_WithEmptyMap_ReturnsSameSettings() { + var initialSettings = new VoyageAIRerankTaskSettings(5, true, true); + VoyageAIRerankTaskSettings updatedSettings = (VoyageAIRerankTaskSettings) initialSettings.updatedTaskSettings(Map.of()); + assertEquals(initialSettings, updatedSettings); + } + + public void testUpdatedTaskSettings_WithNewReturnDocuments_ReturnsUpdatedSettings() { + var initialSettings = new VoyageAIRerankTaskSettings(5, true, true); + Map newSettings = Map.of(VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, false); + VoyageAIRerankTaskSettings updatedSettings = (VoyageAIRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings); + assertFalse(updatedSettings.getReturnDocuments()); + assertTrue(updatedSettings.getTruncation()); + assertEquals(initialSettings.getTopKDocumentsOnly(), updatedSettings.getTopKDocumentsOnly()); + } + + public void testUpdatedTaskSettings_WithNewTopNDocsOnly_ReturnsUpdatedSettings() { + var initialSettings = new VoyageAIRerankTaskSettings(5, true, true); + Map newSettings = Map.of(VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY, 7); + VoyageAIRerankTaskSettings updatedSettings = (VoyageAIRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings); + assertTrue(updatedSettings.getTruncation()); + assertEquals(7, updatedSettings.getTopKDocumentsOnly().intValue()); + assertEquals(initialSettings.getReturnDocuments(), updatedSettings.getReturnDocuments()); + } + + public void testUpdatedTaskSettings_WithMultipleNewValues_ReturnsUpdatedSettings() { + var initialSettings = new VoyageAIRerankTaskSettings(5, true, true); + Map newSettings = Map.of( + VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, + false, + VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY, + 7 + ); + VoyageAIRerankTaskSettings updatedSettings = (VoyageAIRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings); + assertTrue(updatedSettings.getTruncation()); + assertFalse(updatedSettings.getReturnDocuments()); + assertEquals(7, updatedSettings.getTopKDocumentsOnly().intValue()); + } + + @Override + protected Writeable.Reader instanceReader() { + return VoyageAIRerankTaskSettings::new; + } + + @Override + protected VoyageAIRerankTaskSettings createTestInstance() { + return createRandom(); + } + + @Override + protected VoyageAIRerankTaskSettings mutateInstance(VoyageAIRerankTaskSettings instance) throws IOException { + return randomValueOtherThan(instance, VoyageAIRerankTaskSettingsTests::createRandom); + } + + public static Map getTaskSettingsMapEmpty() { + return new HashMap<>(); + } + + public static Map getTaskSettingsMap(@Nullable Integer topNDocumentsOnly, Boolean returnDocuments) { + var map = new HashMap(); + + if (topNDocumentsOnly != null) { + map.put(VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY, topNDocumentsOnly.toString()); + } + + if (returnDocuments != null) { + map.put(VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, returnDocuments.toString()); + } + + return map; + } +}