From 0535410fe626ad6e687843bb9803af53a89edf4c Mon Sep 17 00:00:00 2001 From: fzowl Date: Mon, 3 Feb 2025 18:19:01 +0100 Subject: [PATCH 01/20] VoyageAI embeddings and rerank: - embeddings works, tested - initial rerank code What's missing: - unit and integration tests - rerank request/response mapping and verification --- .../xpack/inference/InferencePlugin.java | 2 + .../voyageai/VoyageAIActionCreator.java | 58 +++ .../voyageai/VoyageAIActionVisitor.java | 21 + .../VoyageAIEmbeddingsRequestManager.java | 57 +++ .../http/sender/VoyageAIRequestManager.java | 28 ++ .../sender/VoyageAIRerankRequestManager.java | 56 +++ .../voyageai/VoyageAIEmbeddingsRequest.java | 96 +++++ .../VoyageAIEmbeddingsRequestEntity.java | 81 ++++ .../request/voyageai/VoyageAIRequest.java | 26 ++ .../voyageai/VoyageAIRerankRequest.java | 86 ++++ .../voyageai/VoyageAIRerankRequestEntity.java | 58 +++ .../request/voyageai/VoyageAIUtils.java | 26 ++ .../VoyageAIEmbeddingsResponseEntity.java | 202 ++++++++++ .../voyageai/VoyageAIErrorResponseEntity.java | 46 +++ .../VoyageAIRerankResponseEntity.java | 158 ++++++++ .../external/voyageai/VoyageAIAccount.java | 32 ++ .../voyageai/VoyageAIResponseHandler.java | 62 +++ .../services/voyageai/VoyageAIModel.java | 68 ++++ .../VoyageAIRateLimitServiceSettings.java | 15 + .../services/voyageai/VoyageAIService.java | 372 ++++++++++++++++++ .../voyageai/VoyageAIServiceFields.java | 13 + .../voyageai/VoyageAIServiceSettings.java | 159 ++++++++ .../embeddings/VoyageAIEmbeddingType.java | 121 ++++++ .../embeddings/VoyageAIEmbeddingsModel.java | 102 +++++ .../VoyageAIEmbeddingsServiceSettings.java | 237 +++++++++++ .../VoyageAIEmbeddingsTaskSettings.java | 249 ++++++++++++ .../voyageai/rerank/VoyageAIRerankModel.java | 102 +++++ .../rerank/VoyageAIRerankServiceSettings.java | 113 ++++++ .../rerank/VoyageAIRerankTaskSettings.java | 169 ++++++++ 29 files changed, 2815 insertions(+) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreator.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionVisitor.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/VoyageAIEmbeddingsRequestManager.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/VoyageAIRequestManager.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/VoyageAIRerankRequestManager.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequest.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRequest.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequest.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIUtils.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIErrorResponseEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIRerankResponseEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/voyageai/VoyageAIAccount.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/voyageai/VoyageAIResponseHandler.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIModel.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIRateLimitServiceSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceFields.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingType.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsModel.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsTaskSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModel.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankServiceSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettings.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 23df62caab430..eadad4c6875c9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -126,6 +126,7 @@ import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService; import org.elasticsearch.xpack.inference.services.mistral.MistralService; import org.elasticsearch.xpack.inference.services.openai.OpenAiService; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIService; import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import java.util.ArrayList; @@ -356,6 +357,7 @@ public List getInferenceServiceFactories() { context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get()), context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()), context -> new JinaAIService(httpFactory.get(), serviceComponents.get()), + context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()), ElasticsearchInternalService::new ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreator.java new file mode 100644 index 0000000000000..e5361e9e45d0c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreator.java @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.voyageai; + +import org.elasticsearch.inference.InputType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.VoyageAIEmbeddingsRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.VoyageAIRerankRequestManager; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel; + +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; + +/** + * Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the voyageai model type. + */ +public class VoyageAIActionCreator implements VoyageAIActionVisitor { + private final Sender sender; + private final ServiceComponents serviceComponents; + + public VoyageAIActionCreator(Sender sender, ServiceComponents serviceComponents) { + this.sender = Objects.requireNonNull(sender); + this.serviceComponents = Objects.requireNonNull(serviceComponents); + } + + @Override + public ExecutableAction create(VoyageAIEmbeddingsModel model, Map taskSettings, InputType inputType) { + var overriddenModel = VoyageAIEmbeddingsModel.of(model, taskSettings, inputType); + var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage( + overriddenModel.getServiceSettings().getCommonSettings().uri(), + "VoyageAI embeddings" + ); + var requestCreator = VoyageAIEmbeddingsRequestManager.of(overriddenModel, serviceComponents.threadPool()); + return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage); + } + + @Override + public ExecutableAction create(VoyageAIRerankModel model, Map taskSettings) { + var overriddenModel = VoyageAIRerankModel.of(model, taskSettings); + var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage( + overriddenModel.getServiceSettings().getCommonSettings().uri(), + "VoyageAI rerank" + ); + var requestCreator = VoyageAIRerankRequestManager.of(overriddenModel, serviceComponents.threadPool()); + return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionVisitor.java new file mode 100644 index 0000000000000..d6732dba95475 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionVisitor.java @@ -0,0 +1,21 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.voyageai; + +import org.elasticsearch.inference.InputType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel; + +import java.util.Map; + +public interface VoyageAIActionVisitor { + ExecutableAction create(VoyageAIEmbeddingsModel model, Map taskSettings, InputType inputType); + + ExecutableAction create(VoyageAIRerankModel model, Map taskSettings); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/VoyageAIEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/VoyageAIEmbeddingsRequestManager.java new file mode 100644 index 0000000000000..be186dc54f8f2 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/VoyageAIEmbeddingsRequestManager.java @@ -0,0 +1,57 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIEmbeddingsRequest; +import org.elasticsearch.xpack.inference.external.response.voyageai.VoyageAIEmbeddingsResponseEntity; +import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIResponseHandler; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel; + +import java.util.List; +import java.util.Objects; +import java.util.function.Supplier; + +public class VoyageAIEmbeddingsRequestManager extends VoyageAIRequestManager { + private static final Logger logger = LogManager.getLogger(VoyageAIEmbeddingsRequestManager.class); + private static final ResponseHandler HANDLER = createEmbeddingsHandler(); + + private static ResponseHandler createEmbeddingsHandler() { + return new VoyageAIResponseHandler("voyageai text embedding", VoyageAIEmbeddingsResponseEntity::fromResponse); + } + + public static VoyageAIEmbeddingsRequestManager of(VoyageAIEmbeddingsModel model, ThreadPool threadPool) { + return new VoyageAIEmbeddingsRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool)); + } + + private final VoyageAIEmbeddingsModel model; + + private VoyageAIEmbeddingsRequestManager(VoyageAIEmbeddingsModel model, ThreadPool threadPool) { + super(threadPool, model); + this.model = Objects.requireNonNull(model); + } + + @Override + public void execute( + InferenceInputs inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest(docsInput, model); + + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/VoyageAIRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/VoyageAIRequestManager.java new file mode 100644 index 0000000000000..819bb74f237d6 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/VoyageAIRequestManager.java @@ -0,0 +1,28 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIModel; + +import java.util.Objects; + +abstract class VoyageAIRequestManager extends BaseRequestManager { + + protected VoyageAIRequestManager(ThreadPool threadPool, VoyageAIModel model) { + super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings()); + } + + record RateLimitGrouping(int apiKeyHash) { + public static RateLimitGrouping of(VoyageAIModel model) { + Objects.requireNonNull(model); + + return new RateLimitGrouping(model.apiKey().hashCode()); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/VoyageAIRerankRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/VoyageAIRerankRequestManager.java new file mode 100644 index 0000000000000..ca91ad4dda276 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/VoyageAIRerankRequestManager.java @@ -0,0 +1,56 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIRerankRequest; +import org.elasticsearch.xpack.inference.external.response.voyageai.VoyageAIRerankResponseEntity; +import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIResponseHandler; +import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel; + +import java.util.Objects; +import java.util.function.Supplier; + +public class VoyageAIRerankRequestManager extends VoyageAIRequestManager { + private static final Logger logger = LogManager.getLogger(VoyageAIRerankRequestManager.class); + private static final ResponseHandler HANDLER = createVoyageAIResponseHandler(); + + private static ResponseHandler createVoyageAIResponseHandler() { + return new VoyageAIResponseHandler("voyageai rerank", (request, response) -> VoyageAIRerankResponseEntity.fromResponse(response)); + } + + public static VoyageAIRerankRequestManager of(VoyageAIRerankModel model, ThreadPool threadPool) { + return new VoyageAIRerankRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool)); + } + + private final VoyageAIRerankModel model; + + private VoyageAIRerankRequestManager(VoyageAIRerankModel model, ThreadPool threadPool) { + super(threadPool, model); + this.model = model; + } + + @Override + public void execute( + InferenceInputs inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + var rerankInput = QueryAndDocsInputs.of(inferenceInputs); + VoyageAIRerankRequest request = new VoyageAIRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model); + + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequest.java new file mode 100644 index 0000000000000..7512bd723d5ad --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequest.java @@ -0,0 +1,96 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.voyageai; + +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.utils.URIBuilder; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIAccount; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings; + +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Objects; + +public class VoyageAIEmbeddingsRequest extends VoyageAIRequest { + + private final VoyageAIAccount account; + private final List input; + private final VoyageAIEmbeddingsServiceSettings serviceSettings; + private final VoyageAIEmbeddingsTaskSettings taskSettings; + private final String model; + private final String inferenceEntityId; + + public VoyageAIEmbeddingsRequest(List input, VoyageAIEmbeddingsModel embeddingsModel) { + Objects.requireNonNull(embeddingsModel); + + account = VoyageAIAccount.of(embeddingsModel, VoyageAIEmbeddingsRequest::buildDefaultUri); + this.input = Objects.requireNonNull(input); + serviceSettings = embeddingsModel.getServiceSettings(); + taskSettings = embeddingsModel.getTaskSettings(); + model = embeddingsModel.getServiceSettings().getCommonSettings().modelId(); + inferenceEntityId = embeddingsModel.getInferenceEntityId(); + } + + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(account.uri()); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString(new VoyageAIEmbeddingsRequestEntity( + input, + serviceSettings, + taskSettings, + model + )).getBytes(StandardCharsets.UTF_8) + ); + httpPost.setEntity(byteEntity); + + decorateWithAuthHeader(httpPost, account); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public String getInferenceEntityId() { + return inferenceEntityId; + } + + @Override + public URI getURI() { + return account.uri(); + } + + @Override + public Request truncate() { + return this; + } + + @Override + public boolean[] getTruncationInfo() { + return null; + } + + public VoyageAIEmbeddingsTaskSettings getTaskSettings() { return taskSettings; } + + public VoyageAIEmbeddingsServiceSettings getServiceSettings() { return serviceSettings; } + + public static URI buildDefaultUri() throws URISyntaxException { + return new URIBuilder().setScheme("https") + .setHost(VoyageAIUtils.HOST) + .setPathSegments(VoyageAIUtils.VERSION_1, VoyageAIUtils.EMBEDDINGS_PATH) + .build(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntity.java new file mode 100644 index 0000000000000..bc28efea7452f --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntity.java @@ -0,0 +1,81 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.voyageai; + +import org.elasticsearch.inference.InputType; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings.invalidInputTypeMessage; + +public record VoyageAIEmbeddingsRequestEntity( + List input, + VoyageAIEmbeddingsServiceSettings serviceSettings, + VoyageAIEmbeddingsTaskSettings taskSettings, + String model +) implements ToXContentObject { + + private static final String DOCUMENT = "document"; + private static final String QUERY = "query"; + private static final String INPUT_FIELD = "input"; + private static final String MODEL_FIELD = "model"; + public static final String INPUT_TYPE_FIELD = "input_type"; + public static final String TRUNCATION_FIELD = "truncation"; + public static final String OUTPUT_DIMENSION = "output_dimension"; + static final String OUTPUT_DTYPE_FIELD = "output_dtype"; + + public VoyageAIEmbeddingsRequestEntity { + Objects.requireNonNull(input); + Objects.requireNonNull(model); + Objects.requireNonNull(taskSettings); + Objects.requireNonNull(serviceSettings); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(INPUT_FIELD, input); + builder.field(MODEL_FIELD, model); + + if (taskSettings.getInputType() != null) { + builder.field(INPUT_TYPE_FIELD, convertToString(taskSettings.getInputType())); + } + + if(taskSettings.getTruncation() != null) { + builder.field(TRUNCATION_FIELD, taskSettings.getTruncation()); + } + + if(taskSettings.getOutputDimension() != null) { + builder.field(OUTPUT_DIMENSION, taskSettings.getOutputDimension()); + } + + if(serviceSettings.getEmbeddingType() != null) { + builder.field(OUTPUT_DTYPE_FIELD, serviceSettings.getEmbeddingType()); + } + + builder.endObject(); + return builder; + } + + static String convertToString(InputType inputType) { + return switch (inputType) { + case INGEST -> DOCUMENT; + case SEARCH -> QUERY; + default -> { + assert false : invalidInputTypeMessage(inputType); + yield null; + } + }; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRequest.java new file mode 100644 index 0000000000000..de5dfe4db07e6 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRequest.java @@ -0,0 +1,26 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.voyageai; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIAccount; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; + +public abstract class VoyageAIRequest implements Request { + + public static void decorateWithAuthHeader(HttpPost request, VoyageAIAccount account) { + request.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); + request.setHeader(createAuthBearerHeader(account.apiKey())); + request.setHeader(VoyageAIUtils.createRequestSourceHeader()); + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequest.java new file mode 100644 index 0000000000000..d1a29d69a482f --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequest.java @@ -0,0 +1,86 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.voyageai; + +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.utils.URIBuilder; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIAccount; +import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel; +import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankTaskSettings; + +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Objects; + +public class VoyageAIRerankRequest extends VoyageAIRequest { + + private final VoyageAIAccount account; + private final String query; + private final List input; + private final VoyageAIRerankTaskSettings taskSettings; + private final String model; + private final String inferenceEntityId; + + public VoyageAIRerankRequest(String query, List input, VoyageAIRerankModel model) { + Objects.requireNonNull(model); + + this.account = VoyageAIAccount.of(model, VoyageAIRerankRequest::buildDefaultUri); + this.input = Objects.requireNonNull(input); + this.query = Objects.requireNonNull(query); + taskSettings = model.getTaskSettings(); + this.model = model.getServiceSettings().modelId(); + inferenceEntityId = model.getInferenceEntityId(); + } + + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(account.uri()); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString(new VoyageAIRerankRequestEntity(query, input, taskSettings, model)).getBytes(StandardCharsets.UTF_8) + ); + httpPost.setEntity(byteEntity); + + decorateWithAuthHeader(httpPost, account); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public String getInferenceEntityId() { + return inferenceEntityId; + } + + @Override + public URI getURI() { + return account.uri(); + } + + @Override + public Request truncate() { + return this; + } + + @Override + public boolean[] getTruncationInfo() { + return null; + } + + public static URI buildDefaultUri() throws URISyntaxException { + return new URIBuilder().setScheme("https") + .setHost(VoyageAIUtils.HOST) + .setPathSegments(VoyageAIUtils.VERSION_1, VoyageAIUtils.RERANK_PATH) + .build(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntity.java new file mode 100644 index 0000000000000..b960d2cf01e0b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntity.java @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.voyageai; + +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankTaskSettings; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public record VoyageAIRerankRequestEntity(String model, String query, List documents, VoyageAIRerankTaskSettings taskSettings) + implements + ToXContentObject { + + private static final String DOCUMENTS_FIELD = "documents"; + private static final String QUERY_FIELD = "query"; + private static final String MODEL_FIELD = "model"; + public static final String TRUNCATION_FIELD = "truncation"; + + public VoyageAIRerankRequestEntity { + Objects.requireNonNull(query); + Objects.requireNonNull(documents); + Objects.requireNonNull(model); + Objects.requireNonNull(taskSettings); + } + + public VoyageAIRerankRequestEntity(String query, List input, VoyageAIRerankTaskSettings taskSettings, String model) { + this(model, query, input, taskSettings != null ? taskSettings : VoyageAIRerankTaskSettings.EMPTY_SETTINGS); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + builder.field(MODEL_FIELD, model); + builder.field(QUERY_FIELD, query); + builder.field(DOCUMENTS_FIELD, documents); + + if (taskSettings.getTopKDocumentsOnly() != null) { + builder.field(VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY, taskSettings.getTopKDocumentsOnly()); + } + + if(taskSettings.getTruncation() != null) { + builder.field(TRUNCATION_FIELD, taskSettings.getTruncation()); + } + + builder.endObject(); + return builder; + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIUtils.java new file mode 100644 index 0000000000000..130093826ee0f --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIUtils.java @@ -0,0 +1,26 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.voyageai; + +import org.apache.http.Header; +import org.apache.http.message.BasicHeader; + +public class VoyageAIUtils { + public static final String HOST = "api.voyageai.com"; + public static final String VERSION_1 = "v1"; + public static final String EMBEDDINGS_PATH = "embeddings"; + public static final String RERANK_PATH = "rerank"; + public static final String REQUEST_SOURCE_HEADER = "Request-Source"; + public static final String ELASTIC_REQUEST_SOURCE = "unspecified:elasticsearch"; + + public static Header createRequestSourceHeader() { + return new BasicHeader(REQUEST_SOURCE_HEADER, ELASTIC_REQUEST_SOURCE); + } + + private VoyageAIUtils() {} +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntity.java new file mode 100644 index 0000000000000..38ee83d2956ba --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntity.java @@ -0,0 +1,202 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + * + * this file was contributed to by a generative AI + */ + +package org.elasticsearch.xpack.inference.external.response.voyageai; + +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.core.CheckedFunction; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIEmbeddingsRequest; +import org.elasticsearch.xpack.inference.external.response.XContentUtils; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingType; + +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.consumeUntilObjectEnd; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList; +import static org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingType.toLowerCase; + +public class VoyageAIEmbeddingsResponseEntity { + private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in VoyageAI embeddings response"; + private static final Map> EMBEDDING_PARSERS = Map.of( + toLowerCase(VoyageAIEmbeddingType.FLOAT), + VoyageAIEmbeddingsResponseEntity::parseFloatEmbeddingsArray, + toLowerCase(VoyageAIEmbeddingType.INT8), + VoyageAIEmbeddingsResponseEntity::parseByteEmbeddingsArray + ); + + private static final String VALID_EMBEDDING_TYPES_STRING = supportedEmbeddingTypes(); + + private static String supportedEmbeddingTypes() { + var validTypes = EMBEDDING_PARSERS.keySet().toArray(String[]::new); + Arrays.sort(validTypes); + return String.join(", ", validTypes); + } + + /** + * Parses the VoyageAI json response. + * For a request like: + * + *
+     *     
+     *        {
+     *          "input": [
+     *            "Sample text 1",
+     *            "Sample text 2"
+     *          ],
+     *          "model": "voyage-3-large"
+     *        }
+     *     
+     * 
+ * + * The response would look like: + * + *
+     * 
+     * {
+     *  "object": "list",
+     *  "data": [
+     *      {
+     *          "object": "embedding",
+     *          "embedding": [
+     *              -0.009327292,
+     *              -0.0028842222,
+     *          ],
+     *          "index": 0
+     *      },
+     *      {
+     *          "object": "embedding",
+     *          "embedding": [ ... ],
+     *          "index": 1
+     *      }
+     *  ],
+     *  "model": "voyage-3-large",
+     *  "usage": {
+     *      "total_tokens": 10
+     *  }
+     * }
+     * 
+     * 
+ */ + public static InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException { + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + VoyageAIEmbeddingType embeddingType = ((VoyageAIEmbeddingsRequest)request).getServiceSettings().getEmbeddingType(); + + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { + moveToFirstToken(jsonParser); + + XContentParser.Token token = jsonParser.currentToken(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); + + positionParserAtTokenAfterField(jsonParser, "data", FAILED_TO_FIND_FIELD_TEMPLATE); + + if(embeddingType == null || embeddingType == VoyageAIEmbeddingType.FLOAT) { + List embeddingList = parseList( + jsonParser, + VoyageAIEmbeddingsResponseEntity::parseEmbeddingObjectFloat + ); + + return new InferenceTextEmbeddingFloatResults(embeddingList); + } else if(embeddingType == VoyageAIEmbeddingType.INT8) { + List embeddingList = parseList( + jsonParser, + VoyageAIEmbeddingsResponseEntity::parseEmbeddingObjectByte + ); + + return new InferenceTextEmbeddingByteResults(embeddingList); + } else { + throw new IllegalArgumentException("Illegal output_dtype value: " + embeddingType); + } + } + } + + private static InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding parseEmbeddingObjectFloat(XContentParser parser) + throws IOException { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + + positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE); + + List embeddingValuesList = parseList(parser, XContentUtils::parseFloat); + // parse and discard the rest of the object + consumeUntilObjectEnd(parser); + + return InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding.of(embeddingValuesList); + } + + private static InferenceTextEmbeddingByteResults.InferenceByteEmbedding parseEmbeddingObjectByte(XContentParser parser) + throws IOException { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + + positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE); + + List embeddingValuesList = parseList(parser, VoyageAIEmbeddingsResponseEntity::parseEmbeddingInt8Entry); + // parse and discard the rest of the object + consumeUntilObjectEnd(parser); + + return InferenceTextEmbeddingByteResults.InferenceByteEmbedding.of(embeddingValuesList); + } + + private static InferenceServiceResults parseByteEmbeddingsArray(XContentParser parser) throws IOException { + var embeddingList = parseList(parser, VoyageAIEmbeddingsResponseEntity::parseByteArrayEntry); + + return new InferenceTextEmbeddingByteResults(embeddingList); + } + + private static InferenceTextEmbeddingByteResults.InferenceByteEmbedding parseByteArrayEntry(XContentParser parser) throws IOException { + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + List embeddingValuesList = parseList(parser, VoyageAIEmbeddingsResponseEntity::parseEmbeddingInt8Entry); + + return InferenceTextEmbeddingByteResults.InferenceByteEmbedding.of(embeddingValuesList); + } + + private static Byte parseEmbeddingInt8Entry(XContentParser parser) throws IOException { + XContentParser.Token token = parser.currentToken(); + ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, token, parser); + var parsedByte = parser.shortValue(); + checkByteBounds(parsedByte); + + return (byte) parsedByte; + } + + private static void checkByteBounds(short value) { + if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) { + throw new IllegalArgumentException("Value [" + value + "] is out of range for a byte"); + } + } + + private static InferenceServiceResults parseFloatEmbeddingsArray(XContentParser parser) throws IOException { + var embeddingList = parseList(parser, VoyageAIEmbeddingsResponseEntity::parseFloatArrayEntry); + + return new InferenceTextEmbeddingFloatResults(embeddingList); + } + + private static InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding parseFloatArrayEntry(XContentParser parser) + throws IOException { + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + List embeddingValuesList = parseList(parser, XContentUtils::parseFloat); + return InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding.of(embeddingValuesList); + } + + private VoyageAIEmbeddingsResponseEntity() {} +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIErrorResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIErrorResponseEntity.java new file mode 100644 index 0000000000000..41ffaa61fcd26 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIErrorResponseEntity.java @@ -0,0 +1,46 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.voyageai; + +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; + +public class VoyageAIErrorResponseEntity extends ErrorResponse { + + private VoyageAIErrorResponseEntity(String errorMessage) { + super(errorMessage); + } + + /** + * Parse an HTTP response into a VoyageAIErrorResponseEntity + * + * @param response The error response + * @return An error entity if the response is JSON with a `detail` field containing the error message + * or null if the response does not contain the message field + */ + public static ErrorResponse fromResponse(HttpResult response) { + try ( + XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON) + .createParser(XContentParserConfiguration.EMPTY, response.body()) + ) { + var responseMap = jsonParser.map(); + var message = (String) responseMap.get("detail"); + if (message != null) { + return new VoyageAIErrorResponseEntity(message); + } + } catch (Exception e) { + // swallow the error + } + + return ErrorResponse.UNDEFINED_ERROR; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIRerankResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIRerankResponseEntity.java new file mode 100644 index 0000000000000..9e48c4a1362bb --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIRerankResponseEntity.java @@ -0,0 +1,158 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + * + * this file was contributed to by a generative AI + */ + +package org.elasticsearch.xpack.inference.external.response.voyageai; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; + +import java.io.IOException; + +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList; +import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownField; +import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownToken; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; + +public class VoyageAIRerankResponseEntity { + + private static final Logger logger = LogManager.getLogger(VoyageAIRerankResponseEntity.class); + + /** + * Parses the VoyageAI ranked response. + * + * For a request like: + * "model": "voyage-reranker-v2-base-multilingual", + * "query": "What is the capital of the United States?", + * "top_n": 3, + * "documents": ["Carson City is the capital city of the American state of Nevada.", + * "The Commonwealth of the Northern Mariana ... Its capital is Saipan.", + * "Washington, D.C. (also known as simply Washington or D.C., ... It is a federal district.", + * "Capital punishment (the death penalty) ... As of 2017, capital punishment is legal in 30 of the 50 states."] + *

+ * The response will look like (without whitespace): + * { + * "id": "1983d114-a6e8-4940-b121-eb4ac3f6f703", + * "results": [ + * { + * "document": { + * "text": "Washington, D.C. is the capital of the United States. It is a federal district." + * }, + * "index": 2, + * "relevance_score": 0.98005307 + * }, + * { + * "document": { + * "text": "Capital punishment (the death penalty) As of 2017, capital punishment is legal in 30 of the 50 states." + * }, + * "index": 3, + * "relevance_score": 0.27904198 + * }, + * { + * "document": { + * "text": "Carson City is the capital city of the American state of Nevada." + * }, + * "index": 0, + * "relevance_score": 0.10194652 + * } + * ], + * "usage": {"total_tokens": 15} + * } + * + * @param response the http response from VoyageAI + * @return the parsed response + * @throws IOException if there is an error parsing the response + */ + public static InferenceServiceResults fromResponse(HttpResult response) throws IOException { + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { + moveToFirstToken(jsonParser); + + XContentParser.Token token = jsonParser.currentToken(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); + + positionParserAtTokenAfterField(jsonParser, "results", FAILED_TO_FIND_FIELD_TEMPLATE); + + token = jsonParser.currentToken(); + if (token == XContentParser.Token.START_ARRAY) { + return new RankedDocsResults(parseList(jsonParser, VoyageAIRerankResponseEntity::parseRankedDocObject)); + } else { + throwUnknownToken(token, jsonParser); + } + + // This should never be reached. The above code should either return successfully or hit the throwUnknownToken + // or throw a parsing exception + throw new IllegalStateException("Reached an invalid state while parsing the VoyageAI response"); + } + } + + private static RankedDocsResults.RankedDoc parseRankedDocObject(XContentParser parser) throws IOException { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + int index = -1; + float relevanceScore = -1; + String documentText = null; + parser.nextToken(); + while (parser.currentToken() != XContentParser.Token.END_OBJECT) { + if (parser.currentToken() == XContentParser.Token.FIELD_NAME) { + switch (parser.currentName()) { + case "index": + parser.nextToken(); // move to VALUE_NUMBER + index = parser.intValue(); + parser.nextToken(); // move to next FIELD_NAME or END_OBJECT + break; + case "relevance_score": + parser.nextToken(); // move to VALUE_NUMBER + relevanceScore = parser.floatValue(); + parser.nextToken(); // move to next FIELD_NAME or END_OBJECT + break; + case "document": + parser.nextToken(); // move to START_OBJECT; document text is wrapped in an object + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + do { + if (parser.currentToken() == XContentParser.Token.FIELD_NAME && parser.currentName().equals("text")) { + parser.nextToken(); // move to VALUE_STRING + documentText = parser.text(); + } + } while (parser.nextToken() != XContentParser.Token.END_OBJECT); + parser.nextToken();// move past END_OBJECT + // parser should now be at the next FIELD_NAME or END_OBJECT + break; + default: + throwUnknownField(parser.currentName(), parser); + } + } else { + parser.nextToken(); + } + } + + if (index == -1) { + logger.warn("Failed to find required field [index] in VoyageAI rerank response"); + } + if (relevanceScore == -1) { + logger.warn("Failed to find required field [relevance_score] in VoyageAI rerank response"); + } + // documentText may or may not be present depending on the request parameter + + return new RankedDocsResults.RankedDoc(index, relevanceScore, documentText); + } + + private VoyageAIRerankResponseEntity() {} + + static String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in VoyageAI rerank response"; +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/voyageai/VoyageAIAccount.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/voyageai/VoyageAIAccount.java new file mode 100644 index 0000000000000..eda5b038b3c3e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/voyageai/VoyageAIAccount.java @@ -0,0 +1,32 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.voyageai; + +import org.elasticsearch.common.CheckedSupplier; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIModel; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.buildUri; + +public record VoyageAIAccount(URI uri, SecureString apiKey) { + + public static VoyageAIAccount of(VoyageAIModel model, CheckedSupplier uriBuilder) { + var uri = buildUri(model.uri(), "VoyageAI", uriBuilder); + + return new VoyageAIAccount(uri, model.apiKey()); + } + + public VoyageAIAccount { + Objects.requireNonNull(uri); + Objects.requireNonNull(apiKey); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/voyageai/VoyageAIResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/voyageai/VoyageAIResponseHandler.java new file mode 100644 index 0000000000000..1611426034a52 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/voyageai/VoyageAIResponseHandler.java @@ -0,0 +1,62 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.voyageai; + +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.response.voyageai.VoyageAIErrorResponseEntity; + +/** + * Defines how to handle various errors returned from the VoyageAI integration. + * + */ +public class VoyageAIResponseHandler extends BaseResponseHandler { + static final String VALIDATION_ERROR_MESSAGE = "Received an input validation error response"; + static final String PAYMENT_ERROR_MESSAGE = "Payment required"; + + public VoyageAIResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, VoyageAIErrorResponseEntity::fromResponse); + } + + /** + * Validates the status code throws an RetryException if not in the range [200, 300). + * + * @param request The http request + * @param result The http response and body + * @throws RetryException Throws if status code is {@code >= 300 or < 200 } + */ + @Override + protected void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException { + if (result.isSuccessfulResponse()) { + return; + } + + // handle error codes + int statusCode = result.response().getStatusLine().getStatusCode(); + if (statusCode == 500) { + throw new RetryException(true, buildError(SERVER_ERROR, request, result)); + } else if (statusCode > 500) { + throw new RetryException(false, buildError(SERVER_ERROR, request, result)); + } else if (statusCode == 429) { + throw new RetryException(true, buildError(RATE_LIMIT, request, result)); + } else if (statusCode == 400 || statusCode == 422) { + throw new RetryException(false, buildError(VALIDATION_ERROR_MESSAGE, request, result)); + } else if (statusCode == 401) { + throw new RetryException(false, buildError(AUTHENTICATION, request, result)); + } else if (statusCode == 402) { + throw new RetryException(false, buildError(PAYMENT_ERROR_MESSAGE, request, result)); + } else if (statusCode >= 300 && statusCode < 400) { + throw new RetryException(false, buildError(REDIRECTION, request, result)); + } else { + throw new RetryException(false, buildError(UNSUCCESSFUL, request, result)); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIModel.java new file mode 100644 index 0000000000000..c8b953c1b8f97 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIModel.java @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.voyageai.VoyageAIActionVisitor; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets; + +import java.net.URI; +import java.util.Map; +import java.util.Objects; + +public abstract class VoyageAIModel extends Model { + private final SecureString apiKey; + private final VoyageAIRateLimitServiceSettings rateLimitServiceSettings; + + public VoyageAIModel( + ModelConfigurations configurations, + ModelSecrets secrets, + @Nullable ApiKeySecrets apiKeySecrets, + VoyageAIRateLimitServiceSettings rateLimitServiceSettings + ) { + super(configurations, secrets); + + this.rateLimitServiceSettings = Objects.requireNonNull(rateLimitServiceSettings); + apiKey = ServiceUtils.apiKey(apiKeySecrets); + } + + protected VoyageAIModel(VoyageAIModel model, TaskSettings taskSettings) { + super(model, taskSettings); + + rateLimitServiceSettings = model.rateLimitServiceSettings(); + apiKey = model.apiKey(); + } + + protected VoyageAIModel(VoyageAIModel model, ServiceSettings serviceSettings) { + super(model, serviceSettings); + + rateLimitServiceSettings = model.rateLimitServiceSettings(); + apiKey = model.apiKey(); + } + + public SecureString apiKey() { + return apiKey; + } + + public VoyageAIRateLimitServiceSettings rateLimitServiceSettings() { + return rateLimitServiceSettings; + } + + public abstract ExecutableAction accept(VoyageAIActionVisitor creator, Map taskSettings, InputType inputType); + + public abstract URI uri(); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIRateLimitServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIRateLimitServiceSettings.java new file mode 100644 index 0000000000000..a4b325fe2db41 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIRateLimitServiceSettings.java @@ -0,0 +1,15 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai; + +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +public interface VoyageAIRateLimitServiceSettings { + RateLimitSettings rateLimitSettings(); + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java new file mode 100644 index 0000000000000..f9264d12877e3 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java @@ -0,0 +1,372 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.util.LazyInitializable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SettingsConfiguration; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; +import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; +import org.elasticsearch.xpack.inference.external.action.voyageai.VoyageAIActionCreator; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel; + +import java.util.EnumSet; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; + +public class VoyageAIService extends SenderService { + public static final String NAME = "voyageai"; + + private static final String SERVICE_NAME = "Voyage AI"; + private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK); + + private static final int DEFAULT_VOYAGE_2_BATCH_SIZE = 72; + private static final int DEFAULT_VOYAGE_3_LITE_BATCH_SIZE = 30; + private static final int DEFAULT_VOYAGE_3_BATCH_SIZE = 10; + private static final int DEFAULT_BATCH_SIZE = 7; + + public VoyageAIService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { + super(factory, serviceComponents); + } + + @Override + public String name() { + return NAME; + } + + @Override + public void parseRequestConfig( + String inferenceEntityId, + TaskType taskType, + Map config, + ActionListener parsedModelListener + ) { + try { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap( + removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS) + ); + } + VoyageAIModel model = createModel( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + chunkingSettings, + serviceSettingsMap, + TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), + ConfigurationParseContext.REQUEST + ); + + throwIfNotEmptyMap(config, NAME); + throwIfNotEmptyMap(serviceSettingsMap, NAME); + throwIfNotEmptyMap(taskSettingsMap, NAME); + + parsedModelListener.onResponse(model); + } catch (Exception e) { + parsedModelListener.onFailure(e); + } + } + + private static VoyageAIModel createModelFromPersistent( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + ChunkingSettings chunkingSettings, + @Nullable Map secretSettings, + String failureMessage + ) { + return createModel( + inferenceEntityId, + taskType, + serviceSettings, + taskSettings, + chunkingSettings, + secretSettings, + failureMessage, + ConfigurationParseContext.PERSISTENT + ); + } + + private static VoyageAIModel createModel( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + ChunkingSettings chunkingSettings, + @Nullable Map secretSettings, + String failureMessage, + ConfigurationParseContext context + ) { + return switch (taskType) { + case TEXT_EMBEDDING -> new VoyageAIEmbeddingsModel( + inferenceEntityId, + NAME, + serviceSettings, + taskSettings, + chunkingSettings, + secretSettings, + context + ); + case RERANK -> new VoyageAIRerankModel(inferenceEntityId, NAME, serviceSettings, taskSettings, secretSettings, context); + default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + }; + } + + @Override + public VoyageAIModel parsePersistedConfigWithSecrets( + String inferenceEntityId, + TaskType taskType, + Map config, + Map secrets + ) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + Map secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); + + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + + return createModelFromPersistent( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + chunkingSettings, + secretSettingsMap, + parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + ); + } + + @Override + public VoyageAIModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + + return createModelFromPersistent( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + chunkingSettings, + null, + parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + ); + } + + @Override + public InferenceServiceConfiguration getConfiguration() { + return Configuration.get(); + } + + @Override + public EnumSet supportedTaskTypes() { + return supportedTaskTypes; + } + + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + + @Override + public void doInfer( + Model model, + InferenceInputs inputs, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener listener + ) { + if (model instanceof VoyageAIModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + VoyageAIModel voyageaiModel = (VoyageAIModel) model; + var actionCreator = new VoyageAIActionCreator(getSender(), getServiceComponents()); + + var action = voyageaiModel.accept(actionCreator, taskSettings, inputType); + action.execute(inputs, timeout, listener); + } + + @Override + protected void doChunkedInfer( + Model model, + DocumentsOnlyInput inputs, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener> listener + ) { + if (model instanceof VoyageAIModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + VoyageAIModel voyageaiModel = (VoyageAIModel) model; + var actionCreator = new VoyageAIActionCreator(getSender(), getServiceComponents()); + + List batchedRequests = new EmbeddingRequestChunker( + inputs.getInputs(), + getBatchSize(voyageaiModel), + EmbeddingRequestChunker.EmbeddingType.fromDenseVectorElementType(model.getServiceSettings().elementType()), + voyageaiModel.getConfigurations().getChunkingSettings() + ).batchRequestsWithListeners(listener); + + for (var request : batchedRequests) { + var action = voyageaiModel.accept(actionCreator, taskSettings, inputType); + action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener()); + } + } + + private static int getBatchSize(VoyageAIModel model) { + int maxBatchSize = DEFAULT_BATCH_SIZE; + + if ("voyage-2".equals(model.getServiceSettings().modelId()) || "voyage-02".equals(model.getServiceSettings().modelId())) { + maxBatchSize = DEFAULT_VOYAGE_2_BATCH_SIZE; + } else if ("voyage-3-lite".equals(model.getServiceSettings().modelId())) { + maxBatchSize = DEFAULT_VOYAGE_3_LITE_BATCH_SIZE; + } else if ("voyage-3".equals(model.getServiceSettings().modelId())) { + maxBatchSize = DEFAULT_VOYAGE_3_BATCH_SIZE; + } + + return maxBatchSize; + } + + /** + * For text embedding models get the embedding size and + * update the service settings. + * + * @param model The new model + * @param listener The listener + */ + @Override + public void checkModelConfig(Model model, ActionListener listener) { + ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); + } + + @Override + public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { + if (model instanceof VoyageAIEmbeddingsModel embeddingsModel) { + var serviceSettings = embeddingsModel.getServiceSettings(); + var similarityFromModel = serviceSettings.similarity(); + var similarityToUse = similarityFromModel == null ? defaultSimilarity() : similarityFromModel; + var maxInputTokens = serviceSettings.maxInputTokens(); + + var updatedServiceSettings = new VoyageAIEmbeddingsServiceSettings( + new VoyageAIServiceSettings( + serviceSettings.getCommonSettings().uri(), + serviceSettings.getCommonSettings().modelId(), + serviceSettings.getCommonSettings().rateLimitSettings() + ), + serviceSettings.getEmbeddingType(), + similarityToUse, + embeddingSize, + maxInputTokens + ); + + return new VoyageAIEmbeddingsModel(embeddingsModel, updatedServiceSettings); + } else { + throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass()); + } + } + + /** + * Return the default similarity measure for the embedding type. + * VoyageAI embeddings are normalized to unit vectors therefore Dot + * Product similarity can be used and is the default for all VoyageAI + * models. + * + * @return The default similarity. + */ + static SimilarityMeasure defaultSimilarity() { + return SimilarityMeasure.DOT_PRODUCT; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.JINA_AI_INTEGRATION_ADDED; + } + + public static class Configuration { + public static InferenceServiceConfiguration get() { + return configuration.getOrCompute(); + } + + private static final LazyInitializable configuration = new LazyInitializable<>( + () -> { + var configurationMap = new HashMap(); + + configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(supportedTaskTypes)); + configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(supportedTaskTypes)); + + return new InferenceServiceConfiguration.Builder().setService(NAME) + .setName(SERVICE_NAME) + .setTaskTypes(supportedTaskTypes) + .setConfigurations(configurationMap) + .build(); + } + ); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceFields.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceFields.java new file mode 100644 index 0000000000000..6d6e84f2e0c11 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceFields.java @@ -0,0 +1,13 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai; + +public class VoyageAIServiceFields { + public static final String TRUNCATION = "truncation"; + public static final String OUTPUT_DIMENSION = "output_dimension"; +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceSettings.java new file mode 100644 index 0000000000000..ea4efa7afabcf --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceSettings.java @@ -0,0 +1,159 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.net.URI; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertToUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createOptionalUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; + +public class VoyageAIServiceSettings extends FilteredXContentObject implements ServiceSettings, VoyageAIRateLimitServiceSettings { + + public static final String NAME = "voyageai_service_settings"; + public static final String MODEL_ID = "model_id"; + private static final Logger logger = LogManager.getLogger(VoyageAIServiceSettings.class); + // See https://jina.ai/contact-sales/#rate-limit + public static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(2_000); + + public static VoyageAIServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + String url = extractOptionalString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException); + URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + VoyageAIService.NAME, + context + ); + + String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new VoyageAIServiceSettings(uri, modelId, rateLimitSettings); + } + + private final URI uri; + private final String modelId; + private final RateLimitSettings rateLimitSettings; + + public VoyageAIServiceSettings(@Nullable URI uri, String modelId, @Nullable RateLimitSettings rateLimitSettings) { + this.uri = uri; + this.modelId = Objects.requireNonNull(modelId); + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + public VoyageAIServiceSettings(@Nullable String url, String modelId, @Nullable RateLimitSettings rateLimitSettings) { + this(createOptionalUri(url), modelId, rateLimitSettings); + } + + public VoyageAIServiceSettings(StreamInput in) throws IOException { + uri = createOptionalUri(in.readOptionalString()); + modelId = in.readOptionalString(); + rateLimitSettings = new RateLimitSettings(in); + } + + @Override + public RateLimitSettings rateLimitSettings() { + return rateLimitSettings; + } + + public URI uri() { + return uri; + } + + @Override + public String modelId() { + return modelId; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + toXContentFragment(builder, params); + + builder.endObject(); + return builder; + } + + public XContentBuilder toXContentFragment(XContentBuilder builder, Params params) throws IOException { + return toXContentFragmentOfExposedFields(builder, params); + } + + @Override + public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + if (uri != null) { + builder.field(URL, uri.toString()); + } + if (modelId != null) { + builder.field(MODEL_ID, modelId); + } + rateLimitSettings.toXContent(builder, params); + + return builder; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.JINA_AI_INTEGRATION_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + var uriToWrite = uri != null ? uri.toString() : null; + out.writeOptionalString(uriToWrite); + out.writeOptionalString(modelId); + rateLimitSettings.writeTo(out); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + VoyageAIServiceSettings that = (VoyageAIServiceSettings) o; + return Objects.equals(uri, that.uri) + && Objects.equals(modelId, that.modelId) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(uri, modelId, rateLimitSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingType.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingType.java new file mode 100644 index 0000000000000..ee414b89bfe5c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingType.java @@ -0,0 +1,121 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai.embeddings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.Strings; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; + +import java.util.Arrays; +import java.util.EnumSet; +import java.util.Locale; +import java.util.Map; + +/** + * Defines the type of embedding that the cohere api should return for a request. + * + *

+ * See api docs for details. + *

+ */ +public enum VoyageAIEmbeddingType { + /** + * Use this when you want to get back the default float embeddings. Valid for all models. + */ + FLOAT(DenseVectorFieldMapper.ElementType.FLOAT, RequestConstants.FLOAT), + /** + * Use this when you want to get back signed int8 embeddings. Valid for only v3 models. + */ + INT8(DenseVectorFieldMapper.ElementType.BYTE, RequestConstants.INT8), + /** + * This is a synonym for INT8 + */ + BYTE(DenseVectorFieldMapper.ElementType.BYTE, RequestConstants.INT8); + + private static final class RequestConstants { + private static final String FLOAT = "float"; + private static final String INT8 = "int8"; + } + + private static final Map ELEMENT_TYPE_TO_VOYAGE_EMBEDDING = Map.of( + DenseVectorFieldMapper.ElementType.FLOAT, + FLOAT, + DenseVectorFieldMapper.ElementType.BYTE, + BYTE + ); + static final EnumSet SUPPORTED_ELEMENT_TYPES = EnumSet.copyOf( + ELEMENT_TYPE_TO_VOYAGE_EMBEDDING.keySet() + ); + + private final DenseVectorFieldMapper.ElementType elementType; + private final String requestString; + + VoyageAIEmbeddingType(DenseVectorFieldMapper.ElementType elementType, String requestString) { + this.elementType = elementType; + this.requestString = requestString; + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } + + public String toRequestString() { + return requestString; + } + + public static String toLowerCase(VoyageAIEmbeddingType type) { + return type.toString().toLowerCase(Locale.ROOT); + } + + public static VoyageAIEmbeddingType fromString(String name) { + return valueOf(name.trim().toUpperCase(Locale.ROOT)); + } + + public static VoyageAIEmbeddingType fromElementType(DenseVectorFieldMapper.ElementType elementType) { + var embedding = ELEMENT_TYPE_TO_VOYAGE_EMBEDDING.get(elementType); + + if (embedding == null) { + var validElementTypes = SUPPORTED_ELEMENT_TYPES.stream() + .map(value -> value.toString().toLowerCase(Locale.ROOT)) + .toArray(String[]::new); + Arrays.sort(validElementTypes); + + throw new IllegalArgumentException( + Strings.format( + "Element type [%s] does not map to a Cohere embedding value, must be one of [%s]", + elementType, + String.join(", ", validElementTypes) + ) + ); + } + + return embedding; + } + + public DenseVectorFieldMapper.ElementType toElementType() { + return elementType; + } + + /** + * Returns an embedding type that is known based on the transport version provided. If the embedding type enum was not yet + * introduced it will be defaulted INT8. + * + * @param embeddingType the value to translate if necessary + * @param version the version that dictates the translation + * @return the embedding type that is known to the version passed in + */ + public static VoyageAIEmbeddingType translateToVersion(VoyageAIEmbeddingType embeddingType, TransportVersion version) { + if (version.before(TransportVersions.V_8_14_0) && embeddingType == BYTE) { + return INT8; + } + + return embeddingType; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsModel.java new file mode 100644 index 0000000000000..b20142ca32970 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsModel.java @@ -0,0 +1,102 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai.embeddings; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.voyageai.VoyageAIActionVisitor; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIModel; + +import java.net.URI; +import java.util.Map; + +public class VoyageAIEmbeddingsModel extends VoyageAIModel { + public static VoyageAIEmbeddingsModel of(VoyageAIEmbeddingsModel model, Map taskSettings, InputType inputType) { + var requestTaskSettings = VoyageAIEmbeddingsTaskSettings.fromMap(taskSettings); + return new VoyageAIEmbeddingsModel( + model, + VoyageAIEmbeddingsTaskSettings.of(model.getTaskSettings(), requestTaskSettings, inputType) + ); + } + + public VoyageAIEmbeddingsModel( + String inferenceId, + String service, + Map serviceSettings, + Map taskSettings, + ChunkingSettings chunkingSettings, + @Nullable Map secrets, + ConfigurationParseContext context + ) { + this( + inferenceId, + service, + VoyageAIEmbeddingsServiceSettings.fromMap(serviceSettings, context), + VoyageAIEmbeddingsTaskSettings.fromMap(taskSettings), + chunkingSettings, + DefaultSecretSettings.fromMap(secrets) + ); + } + + // should only be used for testing + VoyageAIEmbeddingsModel( + String modelId, + String service, + VoyageAIEmbeddingsServiceSettings serviceSettings, + VoyageAIEmbeddingsTaskSettings taskSettings, + ChunkingSettings chunkingSettings, + @Nullable DefaultSecretSettings secretSettings + ) { + super( + new ModelConfigurations(modelId, TaskType.TEXT_EMBEDDING, service, serviceSettings, taskSettings, chunkingSettings), + new ModelSecrets(secretSettings), + secretSettings, + serviceSettings.getCommonSettings() + ); + } + + private VoyageAIEmbeddingsModel(VoyageAIEmbeddingsModel model, VoyageAIEmbeddingsTaskSettings taskSettings) { + super(model, taskSettings); + } + + public VoyageAIEmbeddingsModel(VoyageAIEmbeddingsModel model, VoyageAIEmbeddingsServiceSettings serviceSettings) { + super(model, serviceSettings); + } + + @Override + public VoyageAIEmbeddingsServiceSettings getServiceSettings() { + return (VoyageAIEmbeddingsServiceSettings) super.getServiceSettings(); + } + + @Override + public VoyageAIEmbeddingsTaskSettings getTaskSettings() { + return (VoyageAIEmbeddingsTaskSettings) super.getTaskSettings(); + } + + @Override + public DefaultSecretSettings getSecretSettings() { + return (DefaultSecretSettings) super.getSecretSettings(); + } + + @Override + public ExecutableAction accept(VoyageAIActionVisitor visitor, Map taskSettings, InputType inputType) { + return visitor.create(this, taskSettings, inputType); + } + + @Override + public URI uri() { + return getServiceSettings().getCommonSettings().uri(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettings.java new file mode 100644 index 0000000000000..f3b5b0a77f13a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettings.java @@ -0,0 +1,237 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai.embeddings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings; + +import java.io.IOException; +import java.util.EnumSet; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType; + +public class VoyageAIEmbeddingsServiceSettings extends FilteredXContentObject implements ServiceSettings { + public static final String NAME = "voyageai_embeddings_service_settings"; + + static final String EMBEDDING_TYPE = "embedding_type"; + + public static VoyageAIEmbeddingsServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + var commonServiceSettings = VoyageAIServiceSettings.fromMap(map, context); + + VoyageAIEmbeddingType embeddingTypes = parseEmbeddingType(map, context, validationException); + + SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); + Integer dims = removeAsType(map, DIMENSIONS, Integer.class); + Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new VoyageAIEmbeddingsServiceSettings(commonServiceSettings, embeddingTypes, similarity, dims, maxInputTokens); + } + + static VoyageAIEmbeddingType parseEmbeddingType( + Map map, + ConfigurationParseContext context, + ValidationException validationException + ) { + return switch (context) { + case REQUEST -> Objects.requireNonNullElse( + extractOptionalEnum( + map, + EMBEDDING_TYPE, + ModelConfigurations.SERVICE_SETTINGS, + VoyageAIEmbeddingType::fromString, + EnumSet.allOf(VoyageAIEmbeddingType.class), + validationException + ), + VoyageAIEmbeddingType.FLOAT + ); + case PERSISTENT -> { + var embeddingType = ServiceUtils.extractOptionalString( + map, + EMBEDDING_TYPE, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + yield fromVoyageAIOrDenseVectorEnumValues(embeddingType, validationException); + } + + }; + } + + static VoyageAIEmbeddingType fromVoyageAIOrDenseVectorEnumValues(String enumString, ValidationException validationException) { + if (enumString == null) { + return VoyageAIEmbeddingType.FLOAT; + } + + try { + return VoyageAIEmbeddingType.fromString(enumString); + } catch (IllegalArgumentException ae) { + try { + return VoyageAIEmbeddingType.fromElementType(DenseVectorFieldMapper.ElementType.fromString(enumString)); + } catch (IllegalArgumentException iae) { + var validValuesAsStrings = VoyageAIEmbeddingType.SUPPORTED_ELEMENT_TYPES.stream() + .map(value -> value.toString().toLowerCase(Locale.ROOT)) + .toArray(String[]::new); + validationException.addValidationError( + ServiceUtils.invalidValue(EMBEDDING_TYPE, ModelConfigurations.SERVICE_SETTINGS, enumString, validValuesAsStrings) + ); + return null; + } + } + } + + private final VoyageAIServiceSettings commonSettings; + private final VoyageAIEmbeddingType embeddingType; + private final SimilarityMeasure similarity; + private final Integer dimensions; + private final Integer maxInputTokens; + + public VoyageAIEmbeddingsServiceSettings( + VoyageAIServiceSettings commonSettings, + VoyageAIEmbeddingType embeddingType, + @Nullable SimilarityMeasure similarity, + @Nullable Integer dimensions, + @Nullable Integer maxInputTokens + ) { + this.commonSettings = commonSettings; + this.similarity = similarity; + this.dimensions = dimensions; + this.maxInputTokens = maxInputTokens; + this.embeddingType = Objects.requireNonNull(embeddingType); + } + + public VoyageAIEmbeddingsServiceSettings(StreamInput in) throws IOException { + this.commonSettings = new VoyageAIServiceSettings(in); + this.embeddingType = Objects.requireNonNullElse(in.readOptionalEnum(VoyageAIEmbeddingType.class), VoyageAIEmbeddingType.FLOAT); + this.similarity = in.readOptionalEnum(SimilarityMeasure.class); + this.dimensions = in.readOptionalVInt(); + this.maxInputTokens = in.readOptionalVInt(); + } + + public VoyageAIServiceSettings getCommonSettings() { + return commonSettings; + } + + @Override + public SimilarityMeasure similarity() { + return similarity; + } + + @Override + public Integer dimensions() { + return dimensions; + } + + public Integer maxInputTokens() { + return maxInputTokens; + } + + @Override + public String modelId() { + return commonSettings.modelId(); + } + + public VoyageAIEmbeddingType getEmbeddingType() { + return embeddingType; + } + + + @Override + public DenseVectorFieldMapper.ElementType elementType() { + return embeddingType == null ? DenseVectorFieldMapper.ElementType.FLOAT : embeddingType.toElementType(); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + builder = commonSettings.toXContentFragment(builder, params); + if (similarity != null) { + builder.field(SIMILARITY, similarity); + } + if (dimensions != null) { + builder.field(DIMENSIONS, dimensions); + } + if (maxInputTokens != null) { + builder.field(MAX_INPUT_TOKENS, maxInputTokens); + } + if (embeddingType != null) { + builder.field(EMBEDDING_TYPE, embeddingType); + } + builder.endObject(); + return builder; + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + commonSettings.toXContentFragmentOfExposedFields(builder, params); + + return builder; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.VOYAGE_AI_INTEGRATION_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + commonSettings.writeTo(out); + out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion())); + out.writeOptionalVInt(dimensions); + out.writeOptionalVInt(maxInputTokens); + out.writeOptionalEnum(embeddingType); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + VoyageAIEmbeddingsServiceSettings that = (VoyageAIEmbeddingsServiceSettings) o; + return Objects.equals(commonSettings, that.commonSettings) + && Objects.equals(similarity, that.similarity) + && Objects.equals(dimensions, that.dimensions) + && Objects.equals(maxInputTokens, that.maxInputTokens) + && Objects.equals(embeddingType, that.embeddingType); + } + + @Override + public int hashCode() { + return Objects.hash(commonSettings, similarity, dimensions, maxInputTokens, embeddingType); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsTaskSettings.java new file mode 100644 index 0000000000000..f9cbb9a41099c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsTaskSettings.java @@ -0,0 +1,249 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai.embeddings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; +import static org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceFields.OUTPUT_DIMENSION; +import static org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceFields.TRUNCATION; + +/** + * Defines the task settings for the voyageai text embeddings service. + * + *

+ * See api docs for details. + *

+ */ +public class VoyageAIEmbeddingsTaskSettings implements TaskSettings { + + public static final String NAME = "voyageai_embeddings_task_settings"; + public static final VoyageAIEmbeddingsTaskSettings EMPTY_SETTINGS = new VoyageAIEmbeddingsTaskSettings(null, null, null); + static final String INPUT_TYPE = "input_type"; + static final EnumSet VALID_REQUEST_VALUES = EnumSet.of( + InputType.INGEST, + InputType.SEARCH + ); + + public static VoyageAIEmbeddingsTaskSettings fromMap(Map map) { + if (map == null || map.isEmpty()) { + return EMPTY_SETTINGS; + } + + ValidationException validationException = new ValidationException(); + + InputType inputType = extractOptionalEnum( + map, + INPUT_TYPE, + ModelConfigurations.TASK_SETTINGS, + InputType::fromString, + VALID_REQUEST_VALUES, + validationException + ); + Boolean truncation = extractOptionalBoolean( + map, + TRUNCATION, + validationException + ); + Integer outputDimension = extractOptionalPositiveInteger( + map, + OUTPUT_DIMENSION, + ModelConfigurations.TASK_SETTINGS, + validationException + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new VoyageAIEmbeddingsTaskSettings(inputType, truncation, outputDimension); + } + + /** + * Creates a new {@link VoyageAIEmbeddingsTaskSettings} by preferring non-null fields from the provided parameters. + * For the input type, preference is given to requestInputType if it is not null and not UNSPECIFIED. + * Then preference is given to the requestTaskSettings and finally to originalSettings even if the value is null. + * + * Similarly, for the truncation field preference is given to requestTaskSettings if it is not null and then to + * originalSettings. + * @param originalSettings the settings stored as part of the inference entity configuration + * @param requestTaskSettings the settings passed in within the task_settings field of the request + * @param requestInputType the input type passed in the request parameters + * @return a constructed {@link VoyageAIEmbeddingsTaskSettings} + */ + public static VoyageAIEmbeddingsTaskSettings of( + VoyageAIEmbeddingsTaskSettings originalSettings, + VoyageAIEmbeddingsTaskSettings requestTaskSettings, + InputType requestInputType + ) { + var inputTypeToUse = getValidInputType(originalSettings, requestTaskSettings, requestInputType); + var truncationToUse = getValidTruncation(originalSettings, requestTaskSettings); + var outputDimension = getValidOutputDimension(originalSettings, requestTaskSettings); + + return new VoyageAIEmbeddingsTaskSettings(inputTypeToUse, truncationToUse, outputDimension); + } + + private static InputType getValidInputType( + VoyageAIEmbeddingsTaskSettings originalSettings, + VoyageAIEmbeddingsTaskSettings requestTaskSettings, + InputType requestInputType + ) { + InputType inputTypeToUse = originalSettings.inputType; + + if (VALID_REQUEST_VALUES.contains(requestInputType)) { + inputTypeToUse = requestInputType; + } else if (requestTaskSettings.inputType != null) { + inputTypeToUse = requestTaskSettings.inputType; + } + + return inputTypeToUse; + } + + private static Boolean getValidTruncation( + VoyageAIEmbeddingsTaskSettings originalSettings, + VoyageAIEmbeddingsTaskSettings requestTaskSettings + ) { + return requestTaskSettings.getTruncation() == null ? originalSettings.truncation : requestTaskSettings.getTruncation(); + } + + private static Integer getValidOutputDimension( + VoyageAIEmbeddingsTaskSettings originalSettings, + VoyageAIEmbeddingsTaskSettings requestTaskSettings + ) { + return requestTaskSettings.getOutputDimension() == null + ? originalSettings.outputDimension + : requestTaskSettings.getOutputDimension(); + } + + private final InputType inputType; + private final Boolean truncation; + private final Integer outputDimension; + + public VoyageAIEmbeddingsTaskSettings(StreamInput in) throws IOException { + this( + in.readOptionalEnum(InputType.class), + in.readOptionalBoolean(), + in.readOptionalInt() + ); + } + + public VoyageAIEmbeddingsTaskSettings( + @Nullable InputType inputType, + @Nullable Boolean truncation, + @Nullable Integer outputDimension + ) { + validateInputType(inputType); + this.inputType = inputType; + this.truncation = truncation; + this.outputDimension = outputDimension; + } + + private static void validateInputType(InputType inputType) { + if (inputType == null) { + return; + } + + assert VALID_REQUEST_VALUES.contains(inputType) : invalidInputTypeMessage(inputType); + } + + @Override + public boolean isEmpty() { + return inputType == null && truncation == null && outputDimension == null; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (inputType != null) { + builder.field(INPUT_TYPE, inputType); + } + + if (truncation != null) { + builder.field(TRUNCATION, truncation); + } + + if (outputDimension != null) { + builder.field(OUTPUT_DIMENSION, outputDimension); + } + + builder.endObject(); + return builder; + } + + public InputType getInputType() { + return inputType; + } + + public Boolean getTruncation() { + return truncation; + } + + public Integer getOutputDimension() { + return outputDimension; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.VOYAGE_AI_INTEGRATION_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalEnum(inputType); + out.writeOptionalBoolean(truncation); + out.writeOptionalInt(outputDimension); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + VoyageAIEmbeddingsTaskSettings that = (VoyageAIEmbeddingsTaskSettings) o; + return Objects.equals(inputType, that.inputType) && + Objects.equals(truncation, that.truncation) && + Objects.equals(outputDimension, that.outputDimension); + } + + @Override + public int hashCode() { + return Objects.hash(inputType, truncation, outputDimension); + } + + public static String invalidInputTypeMessage(InputType inputType) { + return Strings.format("received invalid input type value [%s]", inputType.toString()); + } + + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + VoyageAIEmbeddingsTaskSettings updatedSettings = VoyageAIEmbeddingsTaskSettings.fromMap(new HashMap<>(newSettings)); + return of(this, updatedSettings, updatedSettings.inputType != null ? updatedSettings.inputType : this.inputType); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModel.java new file mode 100644 index 0000000000000..fd6c0ee6c5002 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModel.java @@ -0,0 +1,102 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai.rerank; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.voyageai.VoyageAIActionVisitor; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIModel; + +import java.net.URI; +import java.util.Map; + +public class VoyageAIRerankModel extends VoyageAIModel { + public static VoyageAIRerankModel of(VoyageAIRerankModel model, Map taskSettings) { + var requestTaskSettings = VoyageAIRerankTaskSettings.fromMap(taskSettings); + return new VoyageAIRerankModel(model, VoyageAIRerankTaskSettings.of(model.getTaskSettings(), requestTaskSettings)); + } + + public VoyageAIRerankModel( + String inferenceId, + String service, + Map serviceSettings, + Map taskSettings, + @Nullable Map secrets, + ConfigurationParseContext context + ) { + this( + inferenceId, + service, + VoyageAIRerankServiceSettings.fromMap(serviceSettings, context), + VoyageAIRerankTaskSettings.fromMap(taskSettings), + DefaultSecretSettings.fromMap(secrets) + ); + } + + // should only be used for testing + VoyageAIRerankModel( + String modelId, + String service, + VoyageAIRerankServiceSettings serviceSettings, + VoyageAIRerankTaskSettings taskSettings, + @Nullable DefaultSecretSettings secretSettings + ) { + super( + new ModelConfigurations(modelId, TaskType.RERANK, service, serviceSettings, taskSettings), + new ModelSecrets(secretSettings), + secretSettings, + serviceSettings.getCommonSettings() + ); + } + + private VoyageAIRerankModel(VoyageAIRerankModel model, VoyageAIRerankTaskSettings taskSettings) { + super(model, taskSettings); + } + + public VoyageAIRerankModel(VoyageAIRerankModel model, VoyageAIRerankServiceSettings serviceSettings) { + super(model, serviceSettings); + } + + @Override + public VoyageAIRerankServiceSettings getServiceSettings() { + return (VoyageAIRerankServiceSettings) super.getServiceSettings(); + } + + @Override + public VoyageAIRerankTaskSettings getTaskSettings() { + return (VoyageAIRerankTaskSettings) super.getTaskSettings(); + } + + @Override + public DefaultSecretSettings getSecretSettings() { + return (DefaultSecretSettings) super.getSecretSettings(); + } + + /** + * Accepts a visitor to create an executable action. The returned action will not return documents in the response. + * @param visitor _ + * @param taskSettings _ + * @param inputType ignored for rerank task + * @return the rerank action + */ + @Override + public ExecutableAction accept(VoyageAIActionVisitor visitor, Map taskSettings, InputType inputType) { + return visitor.create(this, taskSettings); + } + + @Override + public URI uri() { + return getServiceSettings().getCommonSettings().uri(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankServiceSettings.java new file mode 100644 index 0000000000000..916236aabd246 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankServiceSettings.java @@ -0,0 +1,113 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai.rerank; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +public class VoyageAIRerankServiceSettings extends FilteredXContentObject implements ServiceSettings, VoyageAIRateLimitServiceSettings { + public static final String NAME = "voyageai_rerank_service_settings"; + + private static final Logger logger = LogManager.getLogger(VoyageAIRerankServiceSettings.class); + + public static VoyageAIRerankServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + var commonServiceSettings = VoyageAIServiceSettings.fromMap(map, context); + + return new VoyageAIRerankServiceSettings(commonServiceSettings); + } + + private final VoyageAIServiceSettings commonSettings; + + public VoyageAIRerankServiceSettings(VoyageAIServiceSettings commonSettings) { + this.commonSettings = commonSettings; + } + + public VoyageAIRerankServiceSettings(StreamInput in) throws IOException { + this.commonSettings = new VoyageAIServiceSettings(in); + } + + public VoyageAIServiceSettings getCommonSettings() { + return commonSettings; + } + + @Override + public String modelId() { + return commonSettings.modelId(); + } + + @Override + public RateLimitSettings rateLimitSettings() { + return commonSettings.rateLimitSettings(); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + builder = commonSettings.toXContentFragment(builder, params); + + builder.endObject(); + return builder; + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + commonSettings.toXContentFragmentOfExposedFields(builder, params); + return builder; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.JINA_AI_INTEGRATION_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + commonSettings.writeTo(out); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + VoyageAIRerankServiceSettings that = (VoyageAIRerankServiceSettings) o; + return Objects.equals(commonSettings, that.commonSettings); + } + + @Override + public int hashCode() { + return Objects.hash(commonSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettings.java new file mode 100644 index 0000000000000..db5777065a160 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettings.java @@ -0,0 +1,169 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai.rerank; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; +import static org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceFields.TRUNCATION; + +/** + * Defines the task settings for the VoyageAI rerank service. + * + */ +public class VoyageAIRerankTaskSettings implements TaskSettings { + + public static final String NAME = "voyageai_rerank_task_settings"; + public static final String TOP_K_DOCS_ONLY = "top_k"; + + public static final VoyageAIRerankTaskSettings EMPTY_SETTINGS = new VoyageAIRerankTaskSettings(null, null); + + public static VoyageAIRerankTaskSettings fromMap(Map map) { + ValidationException validationException = new ValidationException(); + + if (map == null || map.isEmpty()) { + return EMPTY_SETTINGS; + } + + Integer topKDocumentsOnly = extractOptionalPositiveInteger( + map, + TOP_K_DOCS_ONLY, + ModelConfigurations.TASK_SETTINGS, + validationException + ); + + Boolean truncation = extractOptionalBoolean( + map, + TRUNCATION, + validationException + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return of(topKDocumentsOnly, truncation); + } + + /** + * Creates a new {@link VoyageAIRerankTaskSettings} by preferring non-null fields from the request settings over the original settings. + * + * @param originalSettings the settings stored as part of the inference entity configuration + * @param requestTaskSettings the settings passed in within the task_settings field of the request + * @return a constructed {@link VoyageAIRerankTaskSettings} + */ + public static VoyageAIRerankTaskSettings of( + VoyageAIRerankTaskSettings originalSettings, + VoyageAIRerankTaskSettings requestTaskSettings + ) { + return new VoyageAIRerankTaskSettings( + requestTaskSettings.getTopKDocumentsOnly() != null + ? requestTaskSettings.getTopKDocumentsOnly() + : originalSettings.getTopKDocumentsOnly(), + requestTaskSettings.getTruncation() != null ? requestTaskSettings.getTruncation() : originalSettings.getTruncation() + + ); + } + + public static VoyageAIRerankTaskSettings of(Integer topKDocumentsOnly, Boolean truncation) { + return new VoyageAIRerankTaskSettings(topKDocumentsOnly, truncation); + } + + private final Integer topKDocumentsOnly; + private final Boolean truncation; + + public VoyageAIRerankTaskSettings(StreamInput in) throws IOException { + this(in.readOptionalInt(), in.readOptionalBoolean()); + } + + public VoyageAIRerankTaskSettings(@Nullable Integer topKDocumentsOnly, @Nullable Boolean truncation) { + this.topKDocumentsOnly = topKDocumentsOnly; + this.truncation = truncation; + } + + @Override + public boolean isEmpty() { + return topKDocumentsOnly == null && truncation == null; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (topKDocumentsOnly != null) { + builder.field(TOP_K_DOCS_ONLY, topKDocumentsOnly); + } + if (truncation != null) { + builder.field(TRUNCATION, truncation); + } + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.VOYAGE_AI_INTEGRATION_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalInt(topKDocumentsOnly); + out.writeOptionalBoolean(truncation); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + VoyageAIRerankTaskSettings that = (VoyageAIRerankTaskSettings) o; + return Objects.equals(topKDocumentsOnly, that.topKDocumentsOnly) && Objects.equals(truncation, that.truncation); + } + + @Override + public int hashCode() { + return Objects.hash(truncation, topKDocumentsOnly); + } + + public static String invalidInputTypeMessage(InputType inputType) { + return Strings.format("received invalid input type value [%s]", inputType.toString()); + } + + public Integer getTopKDocumentsOnly() { + return topKDocumentsOnly; + } + + public Boolean getTruncation() { + return truncation; + } + + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + VoyageAIRerankTaskSettings updatedSettings = VoyageAIRerankTaskSettings.fromMap(new HashMap<>(newSettings)); + return VoyageAIRerankTaskSettings.of(this, updatedSettings); + } +} From 07c39a00c0c5ea50d08034596eabb28c3b8c5945 Mon Sep 17 00:00:00 2001 From: fzowl Date: Mon, 3 Feb 2025 22:10:43 +0100 Subject: [PATCH 02/20] VoyageAI embeddings and rerank: - embeddings works, tested - rerank works, tested (https://www.elastic.co/search-labs/blog/elasticsearch-cohere-rerank) What's missing: - unit and integration tests --- .../voyageai/VoyageAIRerankRequestEntity.java | 5 ++ .../VoyageAIRerankResponseEntity.java | 56 +++++++------------ .../rerank/VoyageAIRerankTaskSettings.java | 43 +++++++++++--- 3 files changed, 59 insertions(+), 45 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntity.java index b960d2cf01e0b..eb6c7898e9e72 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntity.java @@ -23,6 +23,7 @@ public record VoyageAIRerankRequestEntity(String model, String query, List * The response will look like (without whitespace): + * { + * "object": "list", + * "data": [ * { - * "id": "1983d114-a6e8-4940-b121-eb4ac3f6f703", - * "results": [ - * { - * "document": { - * "text": "Washington, D.C. is the capital of the United States. It is a federal district." - * }, - * "index": 2, - * "relevance_score": 0.98005307 - * }, - * { - * "document": { - * "text": "Capital punishment (the death penalty) As of 2017, capital punishment is legal in 30 of the 50 states." - * }, - * "index": 3, - * "relevance_score": 0.27904198 - * }, - * { - * "document": { - * "text": "Carson City is the capital city of the American state of Nevada." - * }, - * "index": 0, - * "relevance_score": 0.10194652 - * } - * ], - * "usage": {"total_tokens": 15} + * "relevance_score": 0.4375, + * "index": 0 + * }, + * { + * "relevance_score": 0.421875, + * "index": 1 * } - * + * ], + * "model": "rerank-2", + * "usage": { + * "total_tokens": 26 + * } + * } * @param response the http response from VoyageAI * @return the parsed response * @throws IOException if there is an error parsing the response @@ -87,7 +76,7 @@ public static InferenceServiceResults fromResponse(HttpResult response) throws I XContentParser.Token token = jsonParser.currentToken(); ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); - positionParserAtTokenAfterField(jsonParser, "results", FAILED_TO_FIND_FIELD_TEMPLATE); + positionParserAtTokenAfterField(jsonParser, "data", FAILED_TO_FIND_FIELD_TEMPLATE); token = jsonParser.currentToken(); if (token == XContentParser.Token.START_ARRAY) { @@ -123,13 +112,8 @@ private static RankedDocsResults.RankedDoc parseRankedDocObject(XContentParser p break; case "document": parser.nextToken(); // move to START_OBJECT; document text is wrapped in an object - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - do { - if (parser.currentToken() == XContentParser.Token.FIELD_NAME && parser.currentName().equals("text")) { - parser.nextToken(); // move to VALUE_STRING - documentText = parser.text(); - } - } while (parser.nextToken() != XContentParser.Token.END_OBJECT); + ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser); + documentText = parser.text(); parser.nextToken();// move past END_OBJECT // parser should now be at the next FIELD_NAME or END_OBJECT break; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettings.java index db5777065a160..0aa65abfda3d2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettings.java @@ -35,9 +35,10 @@ public class VoyageAIRerankTaskSettings implements TaskSettings { public static final String NAME = "voyageai_rerank_task_settings"; + public static final String RETURN_DOCUMENTS = "return_documents"; public static final String TOP_K_DOCS_ONLY = "top_k"; - public static final VoyageAIRerankTaskSettings EMPTY_SETTINGS = new VoyageAIRerankTaskSettings(null, null); + public static final VoyageAIRerankTaskSettings EMPTY_SETTINGS = new VoyageAIRerankTaskSettings(null, null, null); public static VoyageAIRerankTaskSettings fromMap(Map map) { ValidationException validationException = new ValidationException(); @@ -46,6 +47,7 @@ public static VoyageAIRerankTaskSettings fromMap(Map map) { return EMPTY_SETTINGS; } + Boolean returnDocuments = extractOptionalBoolean(map, RETURN_DOCUMENTS, validationException); Integer topKDocumentsOnly = extractOptionalPositiveInteger( map, TOP_K_DOCS_ONLY, @@ -63,7 +65,7 @@ public static VoyageAIRerankTaskSettings fromMap(Map map) { throw validationException; } - return of(topKDocumentsOnly, truncation); + return of(topKDocumentsOnly, returnDocuments, truncation); } /** @@ -81,30 +83,39 @@ public static VoyageAIRerankTaskSettings of( requestTaskSettings.getTopKDocumentsOnly() != null ? requestTaskSettings.getTopKDocumentsOnly() : originalSettings.getTopKDocumentsOnly(), + requestTaskSettings.getReturnDocuments() != null + ? requestTaskSettings.getReturnDocuments() + : originalSettings.getReturnDocuments(), requestTaskSettings.getTruncation() != null ? requestTaskSettings.getTruncation() : originalSettings.getTruncation() ); } - public static VoyageAIRerankTaskSettings of(Integer topKDocumentsOnly, Boolean truncation) { - return new VoyageAIRerankTaskSettings(topKDocumentsOnly, truncation); + public static VoyageAIRerankTaskSettings of(Integer topKDocumentsOnly, Boolean returnDocuments, Boolean truncation) { + return new VoyageAIRerankTaskSettings(topKDocumentsOnly, returnDocuments, truncation); } private final Integer topKDocumentsOnly; + private final Boolean returnDocuments; private final Boolean truncation; public VoyageAIRerankTaskSettings(StreamInput in) throws IOException { - this(in.readOptionalInt(), in.readOptionalBoolean()); + this(in.readOptionalInt(), in.readOptionalBoolean(), in.readOptionalBoolean()); } - public VoyageAIRerankTaskSettings(@Nullable Integer topKDocumentsOnly, @Nullable Boolean truncation) { + public VoyageAIRerankTaskSettings( + @Nullable Integer topKDocumentsOnly, + @Nullable Boolean doReturnDocuments, + @Nullable Boolean truncation + ) { this.topKDocumentsOnly = topKDocumentsOnly; + this.returnDocuments = doReturnDocuments; this.truncation = truncation; } @Override public boolean isEmpty() { - return topKDocumentsOnly == null && truncation == null; + return topKDocumentsOnly == null && returnDocuments == null && truncation == null; } @Override @@ -113,6 +124,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (topKDocumentsOnly != null) { builder.field(TOP_K_DOCS_ONLY, topKDocumentsOnly); } + if (returnDocuments != null) { + builder.field(RETURN_DOCUMENTS, returnDocuments); + } if (truncation != null) { builder.field(TRUNCATION, truncation); } @@ -133,6 +147,7 @@ public TransportVersion getMinimalSupportedVersion() { @Override public void writeTo(StreamOutput out) throws IOException { out.writeOptionalInt(topKDocumentsOnly); + out.writeOptionalBoolean(returnDocuments); out.writeOptionalBoolean(truncation); } @@ -141,12 +156,14 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; VoyageAIRerankTaskSettings that = (VoyageAIRerankTaskSettings) o; - return Objects.equals(topKDocumentsOnly, that.topKDocumentsOnly) && Objects.equals(truncation, that.truncation); + return Objects.equals(topKDocumentsOnly, that.topKDocumentsOnly) && + Objects.equals(returnDocuments, that.returnDocuments) && + Objects.equals(truncation, that.truncation); } @Override public int hashCode() { - return Objects.hash(truncation, topKDocumentsOnly); + return Objects.hash(truncation, returnDocuments, topKDocumentsOnly); } public static String invalidInputTypeMessage(InputType inputType) { @@ -157,6 +174,14 @@ public Integer getTopKDocumentsOnly() { return topKDocumentsOnly; } + public Boolean getDoesReturnDocuments() { + return returnDocuments; + } + + public Boolean getReturnDocuments() { + return returnDocuments; + } + public Boolean getTruncation() { return truncation; } From 91dee7fa2dcd0d5f68dc00d5393a1780586d86f1 Mon Sep 17 00:00:00 2001 From: fzowl Date: Mon, 3 Feb 2025 22:23:09 +0100 Subject: [PATCH 03/20] VoyageAI embeddings and rerank: - embeddings works, tested - rerank works, tested (https://www.elastic.co/search-labs/blog/elasticsearch-cohere-rerank) What's missing: - unit and integration tests --- server/src/main/java/org/elasticsearch/TransportVersions.java | 2 +- .../xpack/inference/services/voyageai/VoyageAIService.java | 2 +- .../inference/services/voyageai/VoyageAIServiceSettings.java | 2 +- .../services/voyageai/embeddings/VoyageAIEmbeddingType.java | 4 ++-- .../voyageai/rerank/VoyageAIRerankServiceSettings.java | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 5065d84f84978..c9fde8501621a 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -172,7 +172,6 @@ static TransportVersion def(int id) { public static final TransportVersion TIMEOUT_GET_PARAM_FOR_RESOLVE_CLUSTER = def(8_838_0_00); public static final TransportVersion INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING = def(8_839_0_00); public static final TransportVersion ML_INFERENCE_IBM_WATSONX_RERANK_ADDED = def(8_840_0_00); - public static final TransportVersion COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_X = def(8_840_0_01); public static final TransportVersion ELASTICSEARCH_9_0 = def(9_000_0_00); public static final TransportVersion COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED = def(9_001_0_00); public static final TransportVersion REMOVE_SNAPSHOT_FAILURES = def(9_002_0_00); @@ -180,6 +179,7 @@ static TransportVersion def(int id) { public static final TransportVersion REMOVE_DESIRED_NODE_VERSION = def(9_004_0_00); + public static final TransportVersion VOYAGE_AI_INTEGRATION_ADDED = def(9_002_0_00); /* * STOP! READ THIS FIRST! No, really, * ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _ diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java index f9264d12877e3..63d1b8a4acc95 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java @@ -346,7 +346,7 @@ static SimilarityMeasure defaultSimilarity() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.JINA_AI_INTEGRATION_ADDED; + return TransportVersions.VOYAGE_AI_INTEGRATION_ADDED; } public static class Configuration { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceSettings.java index ea4efa7afabcf..9a58f39b000e5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceSettings.java @@ -131,7 +131,7 @@ public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.JINA_AI_INTEGRATION_ADDED; + return TransportVersions.VOYAGE_AI_INTEGRATION_ADDED; } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingType.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingType.java index ee414b89bfe5c..432d8797ccfce 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingType.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingType.java @@ -18,7 +18,7 @@ import java.util.Map; /** - * Defines the type of embedding that the cohere api should return for a request. + * Defines the type of embedding that the VoyageAI api should return for a request. * *

* See api docs for details. @@ -89,7 +89,7 @@ public static VoyageAIEmbeddingType fromElementType(DenseVectorFieldMapper.Eleme throw new IllegalArgumentException( Strings.format( - "Element type [%s] does not map to a Cohere embedding value, must be one of [%s]", + "Element type [%s] does not map to a VoyageAI embedding value, must be one of [%s]", elementType, String.join(", ", validElementTypes) ) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankServiceSettings.java index 916236aabd246..1d3607922c5c2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankServiceSettings.java @@ -90,7 +90,7 @@ protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder buil @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.JINA_AI_INTEGRATION_ADDED; + return TransportVersions.VOYAGE_AI_INTEGRATION_ADDED; } @Override From 6f11414171a299f8c2d940bc908f9071e631454b Mon Sep 17 00:00:00 2001 From: fzowl Date: Mon, 3 Feb 2025 22:27:09 +0100 Subject: [PATCH 04/20] VoyageAI embeddings and rerank: - embeddings works, tested - rerank works, tested (https://www.elastic.co/search-labs/blog/elasticsearch-cohere-rerank) What's missing: - unit and integration tests --- .../request/voyageai/VoyageAIEmbeddingsRequestEntity.java | 2 +- .../inference/services/voyageai/VoyageAIServiceSettings.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntity.java index bc28efea7452f..ace5a90120513 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntity.java @@ -17,7 +17,7 @@ import java.util.List; import java.util.Objects; -import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings.invalidInputTypeMessage; +import static org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings.invalidInputTypeMessage; public record VoyageAIEmbeddingsRequestEntity( List input, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceSettings.java index 9a58f39b000e5..76450913453ff 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceSettings.java @@ -38,7 +38,7 @@ public class VoyageAIServiceSettings extends FilteredXContentObject implements S public static final String NAME = "voyageai_service_settings"; public static final String MODEL_ID = "model_id"; private static final Logger logger = LogManager.getLogger(VoyageAIServiceSettings.class); - // See https://jina.ai/contact-sales/#rate-limit + // See https://docs.voyageai.com/docs/rate-limits public static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(2_000); public static VoyageAIServiceSettings fromMap(Map map, ConfigurationParseContext context) { From 050d5b2ed071b00e5cfeb6d3443bb36c072e939b Mon Sep 17 00:00:00 2001 From: fzowl Date: Tue, 4 Feb 2025 18:54:12 +0100 Subject: [PATCH 05/20] Adding initial tests Moving dimensions to ServiceSettings --- .../VoyageAIEmbeddingsRequestEntity.java | 4 +- .../VoyageAIEmbeddingsTaskSettings.java | 40 +- .../VoyageAIServiceSettingsTests.java | 175 ++ .../voyageai/VoyageAIServiceTests.java | 1941 +++++++++++++++++ .../VoyageAIEmbeddingsModelTests.java | 171 ++ ...oyageAIEmbeddingsServiceSettingsTests.java | 188 ++ .../VoyageAIEmbeddingsTaskSettingsTests.java | 194 ++ .../rerank/VoyageAIRerankModelTests.java | 77 + .../VoyageAIRerankServiceSettingsTests.java | 84 + .../VoyageAIRerankTaskSettingsTests.java | 133 ++ 10 files changed, 2973 insertions(+), 34 deletions(-) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceSettingsTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsModelTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettingsTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsTaskSettingsTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModelTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankServiceSettingsTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettingsTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntity.java index ace5a90120513..ebd68bef359a7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntity.java @@ -56,8 +56,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(TRUNCATION_FIELD, taskSettings.getTruncation()); } - if(taskSettings.getOutputDimension() != null) { - builder.field(OUTPUT_DIMENSION, taskSettings.getOutputDimension()); + if(serviceSettings.dimensions() != null) { + builder.field(OUTPUT_DIMENSION, serviceSettings.dimensions()); } if(serviceSettings.getEmbeddingType() != null) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsTaskSettings.java index f9cbb9a41099c..b89c414e26520 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsTaskSettings.java @@ -41,7 +41,7 @@ public class VoyageAIEmbeddingsTaskSettings implements TaskSettings { public static final String NAME = "voyageai_embeddings_task_settings"; - public static final VoyageAIEmbeddingsTaskSettings EMPTY_SETTINGS = new VoyageAIEmbeddingsTaskSettings(null, null, null); + public static final VoyageAIEmbeddingsTaskSettings EMPTY_SETTINGS = new VoyageAIEmbeddingsTaskSettings(null, null); static final String INPUT_TYPE = "input_type"; static final EnumSet VALID_REQUEST_VALUES = EnumSet.of( InputType.INGEST, @@ -79,7 +79,7 @@ public static VoyageAIEmbeddingsTaskSettings fromMap(Map map) { throw validationException; } - return new VoyageAIEmbeddingsTaskSettings(inputType, truncation, outputDimension); + return new VoyageAIEmbeddingsTaskSettings(inputType, truncation); } /** @@ -101,9 +101,8 @@ public static VoyageAIEmbeddingsTaskSettings of( ) { var inputTypeToUse = getValidInputType(originalSettings, requestTaskSettings, requestInputType); var truncationToUse = getValidTruncation(originalSettings, requestTaskSettings); - var outputDimension = getValidOutputDimension(originalSettings, requestTaskSettings); - return new VoyageAIEmbeddingsTaskSettings(inputTypeToUse, truncationToUse, outputDimension); + return new VoyageAIEmbeddingsTaskSettings(inputTypeToUse, truncationToUse); } private static InputType getValidInputType( @@ -129,36 +128,23 @@ private static Boolean getValidTruncation( return requestTaskSettings.getTruncation() == null ? originalSettings.truncation : requestTaskSettings.getTruncation(); } - private static Integer getValidOutputDimension( - VoyageAIEmbeddingsTaskSettings originalSettings, - VoyageAIEmbeddingsTaskSettings requestTaskSettings - ) { - return requestTaskSettings.getOutputDimension() == null - ? originalSettings.outputDimension - : requestTaskSettings.getOutputDimension(); - } - private final InputType inputType; private final Boolean truncation; - private final Integer outputDimension; public VoyageAIEmbeddingsTaskSettings(StreamInput in) throws IOException { this( in.readOptionalEnum(InputType.class), - in.readOptionalBoolean(), - in.readOptionalInt() + in.readOptionalBoolean() ); } public VoyageAIEmbeddingsTaskSettings( @Nullable InputType inputType, - @Nullable Boolean truncation, - @Nullable Integer outputDimension + @Nullable Boolean truncation ) { validateInputType(inputType); this.inputType = inputType; this.truncation = truncation; - this.outputDimension = outputDimension; } private static void validateInputType(InputType inputType) { @@ -171,7 +157,7 @@ private static void validateInputType(InputType inputType) { @Override public boolean isEmpty() { - return inputType == null && truncation == null && outputDimension == null; + return inputType == null && truncation == null; } @Override @@ -185,10 +171,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(TRUNCATION, truncation); } - if (outputDimension != null) { - builder.field(OUTPUT_DIMENSION, outputDimension); - } - builder.endObject(); return builder; } @@ -201,10 +183,6 @@ public Boolean getTruncation() { return truncation; } - public Integer getOutputDimension() { - return outputDimension; - } - @Override public String getWriteableName() { return NAME; @@ -219,7 +197,6 @@ public TransportVersion getMinimalSupportedVersion() { public void writeTo(StreamOutput out) throws IOException { out.writeOptionalEnum(inputType); out.writeOptionalBoolean(truncation); - out.writeOptionalInt(outputDimension); } @Override @@ -228,13 +205,12 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; VoyageAIEmbeddingsTaskSettings that = (VoyageAIEmbeddingsTaskSettings) o; return Objects.equals(inputType, that.inputType) && - Objects.equals(truncation, that.truncation) && - Objects.equals(outputDimension, that.outputDimension); + Objects.equals(truncation, that.truncation); } @Override public int hashCode() { - return Objects.hash(inputType, truncation, outputDimension); + return Objects.hash(inputType, truncation); } public static String invalidInputTypeMessage(InputType inputType) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceSettingsTests.java new file mode 100644 index 0000000000000..1fad1c6554da6 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceSettingsTests.java @@ -0,0 +1,175 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class VoyageAIServiceSettingsTests extends AbstractWireSerializingTestCase { + + public static VoyageAIServiceSettings createRandomWithNonNullUrl() { + return createRandom(randomAlphaOfLength(15)); + } + + /** + * The created settings can have a url set to null. + */ + public static VoyageAIServiceSettings createRandom() { + var url = randomBoolean() ? randomAlphaOfLength(15) : null; + return createRandom(url); + } + + private static VoyageAIServiceSettings createRandom(String url) { + var model = randomAlphaOfLength(15); + + return new VoyageAIServiceSettings(ServiceUtils.createOptionalUri(url), model, RateLimitSettingsTests.createRandom()); + } + + public void testFromMap() { + var url = "https://www.abc.com"; + var model = "model"; + var serviceSettings = VoyageAIServiceSettings.fromMap( + new HashMap<>(Map.of(ServiceFields.URL, url, VoyageAIServiceSettings.MODEL_ID, model)), + ConfigurationParseContext.REQUEST + ); + + MatcherAssert.assertThat(serviceSettings, is(new VoyageAIServiceSettings(ServiceUtils.createUri(url), model, null))); + } + + public void testFromMap_WithRateLimit() { + var url = "https://www.abc.com"; + var model = "model"; + var serviceSettings = VoyageAIServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.URL, + url, + VoyageAIServiceSettings.MODEL_ID, + model, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 3)) + ) + ), + ConfigurationParseContext.REQUEST + ); + + MatcherAssert.assertThat( + serviceSettings, + is(new VoyageAIServiceSettings(ServiceUtils.createUri(url), model, new RateLimitSettings(3))) + ); + } + + public void testFromMap_WhenUsingModelId() { + var url = "https://www.abc.com"; + var model = "model"; + var serviceSettings = VoyageAIServiceSettings.fromMap( + new HashMap<>(Map.of(ServiceFields.URL, url, VoyageAIServiceSettings.MODEL_ID, model)), + ConfigurationParseContext.PERSISTENT + ); + + MatcherAssert.assertThat(serviceSettings, is(new VoyageAIServiceSettings(ServiceUtils.createUri(url), model, null))); + } + + public void testFromMap_MissingUrl_DoesNotThrowException() { + var serviceSettings = VoyageAIServiceSettings.fromMap( + new HashMap<>(Map.of(VoyageAIServiceSettings.MODEL_ID, "model")), + ConfigurationParseContext.PERSISTENT + ); + assertNull(serviceSettings.uri()); + } + + public void testFromMap_EmptyUrl_ThrowsError() { + var thrownException = expectThrows( + ValidationException.class, + () -> VoyageAIServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, "")), ConfigurationParseContext.PERSISTENT) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + containsString( + Strings.format( + "Validation Failed: 1: [service_settings] Invalid value empty string. [%s] must be a non-empty string;", + ServiceFields.URL + ) + ) + ); + } + + public void testFromMap_InvalidUrl_ThrowsError() { + var url = "https://www.abc^.com"; + var thrownException = expectThrows( + ValidationException.class, + () -> VoyageAIServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, url)), ConfigurationParseContext.PERSISTENT) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + containsString( + Strings.format("Validation Failed: 1: [service_settings] Invalid url [%s] received for field [%s]", url, ServiceFields.URL) + ) + ); + } + + public void testXContent_WritesModelId() throws IOException { + var entity = new VoyageAIServiceSettings((String) null, "model", new RateLimitSettings(1)); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(""" + {"model_id":"model","rate_limit":{"requests_per_minute":1}}""")); + } + + @Override + protected Writeable.Reader instanceReader() { + return VoyageAIServiceSettings::new; + } + + @Override + protected VoyageAIServiceSettings createTestInstance() { + return createRandomWithNonNullUrl(); + } + + @Override + protected VoyageAIServiceSettings mutateInstance(VoyageAIServiceSettings instance) throws IOException { + return randomValueOtherThan(instance, VoyageAIServiceSettingsTests::createRandom); + } + + public static Map getServiceSettingsMap(@Nullable String url, String model) { + var map = new HashMap(); + + if (url != null) { + map.put(ServiceFields.URL, url); + } + + map.put(VoyageAIServiceSettings.MODEL_ID, model); + + return map; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java new file mode 100644 index 0000000000000..1907599f021ab --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java @@ -0,0 +1,1941 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIService; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettingsTests; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettingsTests; +import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModelTests; +import org.hamcrest.CoreMatchers; +import org.hamcrest.MatcherAssert; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; +import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; +import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettings; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +public class VoyageAIServiceTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testParseRequestConfig_CreatesAVoyageAIEmbeddingsModel() throws IOException { + try (var service = createVoyageAIService()) { + ActionListener modelListener = ActionListener.wrap(model -> { + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null, null))); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + }, e -> fail("Model parsing should have succeeded " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST), + getSecretSettingsMap("secret") + ), + modelListener + ); + + } + } + + public void testParseRequestConfig_CreatesAVoyageAIEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { + try (var service = createVoyageAIService()) { + ActionListener modelListener = ActionListener.wrap(model -> { + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST))); + MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + }, e -> fail("Model parsing should have succeeded " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ), + modelListener + ); + + } + } + + public void testParseRequestConfig_CreatesAVoyageAIEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { + try (var service = createVoyageAIService()) { + ActionListener modelListener = ActionListener.wrap(model -> { + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST))); + MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + }, e -> fail("Model parsing should have succeeded " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST), + getSecretSettingsMap("secret") + ), + modelListener + ); + + } + } + + public void testParseRequestConfig_OptionalTaskSettings() throws IOException { + try (var service = createVoyageAIService()) { + + ActionListener modelListener = ActionListener.wrap(model -> { + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), equalTo(VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS)); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + }, e -> fail("Model parsing should have succeeded " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + getSecretSettingsMap("secret") + ), + modelListener + ); + + } + } + + public void testParseRequestConfig_ThrowsUnsupportedTaskType() throws IOException { + try (var service = createVoyageAIService()) { + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "The [voyageai] service does not support task type [sparse_embedding]" + ); + + service.parseRequestConfig( + "id", + TaskType.SPARSE_EMBEDDING, + getRequestConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(), + getSecretSettingsMap("secret") + ), + failureListener + ); + } + } + + private static ActionListener getModelListenerForException(Class exceptionClass, String expectedMessage) { + return ActionListener.wrap((model) -> fail("Model parsing should have failed"), e -> { + MatcherAssert.assertThat(e, instanceOf(exceptionClass)); + MatcherAssert.assertThat(e.getMessage(), is(expectedMessage)); + }); + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException { + try (var service = createVoyageAIService()) { + var config = getRequestConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(), + getSecretSettingsMap("secret") + ); + config.put("extra_key", "value"); + + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "Model configuration contains settings [{extra_key=value}] unknown to the [voyageai] service" + ); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException { + try (var service = createVoyageAIService()) { + var serviceSettings = VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"); + serviceSettings.put("extra_key", "value"); + + var config = getRequestConfigMap( + serviceSettings, + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + getSecretSettingsMap("secret") + ); + + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "Model configuration contains settings [{extra_key=value}] unknown to the [voyageai] service" + ); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException { + try (var service = createVoyageAIService()) { + var taskSettingsMap = VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST); + taskSettingsMap.put("extra_key", "value"); + + var config = getRequestConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + taskSettingsMap, + getSecretSettingsMap("secret") + ); + + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "Model configuration contains settings [{extra_key=value}] unknown to the [voyageai] service" + ); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); + + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap() throws IOException { + try (var service = createVoyageAIService()) { + var secretSettingsMap = getSecretSettingsMap("secret"); + secretSettingsMap.put("extra_key", "value"); + + var config = getRequestConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(), + secretSettingsMap + ); + + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "Model configuration contains settings [{extra_key=value}] unknown to the [voyageai] service" + ); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); + } + } + + public void testParseRequestConfig_CreatesAVoyageAIEmbeddingsModelWithoutUrl() throws IOException { + try (var service = createVoyageAIService()) { + var modelListener = ActionListener.wrap((model) -> { + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + assertNull(embeddingsModel.getServiceSettings().getCommonSettings().uri()); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS)); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + }, (e) -> fail("Model parsing should have succeeded " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, "model"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(), + getSecretSettingsMap("secret") + ), + modelListener + ); + + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAVoyageAIEmbeddingsModel() throws IOException { + try (var service = createVoyageAIService()) { + var persistedConfig = getPersistedConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null))); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAVoyageAIEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { + try (var service = createVoyageAIService()) { + var persistedConfig = getPersistedConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null))); + MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAVoyageAIEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { + try (var service = createVoyageAIService()) { + var persistedConfig = getPersistedConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null))); + MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidModel() throws IOException { + try (var service = createVoyageAIService()) { + var persistedConfig = getPersistedConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "oldmodel"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(), + getSecretSettingsMap("secret") + ); + + var thrownException = expectThrows( + ElasticsearchStatusException.class, + () -> service.parsePersistedConfigWithSecrets( + "id", + TaskType.SPARSE_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + is("Failed to parse stored model [id] for [voyageai] service, please delete and add the service again") + ); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAVoyageAIEmbeddingsModelWithoutUrl() throws IOException { + try (var service = createVoyageAIService()) { + var persistedConfig = getPersistedConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, "model"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + assertNull(embeddingsModel.getServiceSettings().getCommonSettings().uri()); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST))); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + try (var service = createVoyageAIService()) { + var persistedConfig = getPersistedConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH), + getSecretSettingsMap("secret") + ); + persistedConfig.config().put("extra_key", "value"); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH))); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { + try (var service = createVoyageAIService()) { + var secretSettingsMap = getSecretSettingsMap("secret"); + secretSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(), + secretSettingsMap + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS)); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSecrets() throws IOException { + try (var service = createVoyageAIService()) { + var persistedConfig = getPersistedConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + getSecretSettingsMap("secret") + ); + persistedConfig.secrets().put("extra_key", "value"); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null))); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + try (var service = createVoyageAIService()) { + var serviceSettingsMap = VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"); + serviceSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap( + serviceSettingsMap, + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS)); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + try (var service = createVoyageAIService()) { + var taskSettingsMap = VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH); + taskSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + taskSettingsMap, + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH))); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfig_CreatesAVoyageAIEmbeddingsModel() throws IOException { + try (var service = createVoyageAIService()) { + var persistedConfig = getPersistedConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null) + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null))); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_CreatesAVoyageAIEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { + try (var service = createVoyageAIService()) { + var persistedConfig = getPersistedConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + createRandomChunkingSettingsMap() + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null))); + MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_CreatesAVoyageAIEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { + try (var service = createVoyageAIService()) { + var persistedConfig = getPersistedConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null) + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null))); + MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() throws IOException { + try (var service = createVoyageAIService()) { + var persistedConfig = getPersistedConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model_old"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty() + ); + + var thrownException = expectThrows( + ElasticsearchStatusException.class, + () -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config()) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + is("Failed to parse stored model [id] for [voyageai] service, please delete and add the service again") + ); + } + } + + public void testParsePersistedConfig_CreatesAVoyageAIEmbeddingsModelWithoutUrl() throws IOException { + try (var service = createVoyageAIService()) { + var persistedConfig = getPersistedConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, "model"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null) + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + assertNull(embeddingsModel.getServiceSettings().getCommonSettings().uri()); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null))); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + try (var service = createVoyageAIService()) { + var persistedConfig = getPersistedConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty() + ); + persistedConfig.config().put("extra_key", "value"); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS)); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + try (var service = createVoyageAIService()) { + var serviceSettingsMap = VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"); + serviceSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap( + serviceSettingsMap, + VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH) + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH))); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + try (var service = createVoyageAIService()) { + var taskSettingsMap = VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST); + taskSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap( + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + taskSettingsMap + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); + + var embeddingsModel = (VoyageAIEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST))); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void testInfer_ThrowsErrorWhenModelIsNotVoyageAIModel() throws IOException { + var sender = mock(Sender.class); + + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + var mockModel = getInvalidModel("model_id", "service_name"); + + try (var service = new VoyageAIService(factory, createWithEmptySettings(threadPool))) { + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + mockModel, + null, + List.of(""), + false, + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + MatcherAssert.assertThat( + thrownException.getMessage(), + is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") + ); + + verify(factory, times(1)).createSender(); + verify(sender, times(1)).start(); + } + + verify(sender, times(1)).close(); + verifyNoMoreInteractions(factory); + verifyNoMoreInteractions(sender); + } + + public void testCheckModelConfig_UpdatesDimensions() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "model": "voyage-clip-v2", + "object": "list", + "usage": { + "total_tokens": 5, + "prompt_tokens": 5 + }, + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.123, + -0.123 + ] + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = VoyageAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + 10, + 1, + "voyage-clip-v2" + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.checkModelConfig(model, listener); + var result = listener.actionGet(TIMEOUT); + + MatcherAssert.assertThat( + result, + // the dimension is set to 2 because there are 2 embeddings returned from the mock server + is( + VoyageAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + 10, + 2, + "voyage-clip-v2" + ) + ) + ); + } + } + + public void testCheckModelConfig_UpdatesSimilarityToDotProduct_WhenItIsNull() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "model": "voyage-clip-v2", + "object": "list", + "usage": { + "total_tokens": 5, + "prompt_tokens": 5 + }, + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.123, + -0.123 + ] + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = VoyageAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + 10, + 1, + "voyage-clip-v2", + null + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.checkModelConfig(model, listener); + var result = listener.actionGet(TIMEOUT); + + MatcherAssert.assertThat( + result, + // the dimension is set to 2 because there are 2 embeddings returned from the mock server + is( + VoyageAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + 10, + 2, + "voyage-clip-v2", + SimilarityMeasure.DOT_PRODUCT + ) + ) + ); + } + } + + public void testCheckModelConfig_DoesNotUpdateSimilarity_WhenItIsSpecifiedAsCosine() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "model": "voyage-clip-v2", + "object": "list", + "usage": { + "total_tokens": 5, + "prompt_tokens": 5 + }, + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.123, + -0.123 + ] + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = VoyageAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + 10, + 1, + "voyage-clip-v2", + SimilarityMeasure.COSINE + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.checkModelConfig(model, listener); + var result = listener.actionGet(TIMEOUT); + + MatcherAssert.assertThat( + result, + // the dimension is set to 2 because there are 2 embeddings returned from the mock server + is( + VoyageAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + 10, + 2, + "voyage-clip-v2", + SimilarityMeasure.COSINE + ) + ) + ); + } + } + + public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel() throws IOException { + testUpdateModelWithEmbeddingDetails_Successful(null); + } + + public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel() throws IOException { + testUpdateModelWithEmbeddingDetails_Successful(randomFrom(SimilarityMeasure.values())); + } + + private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + var embeddingSize = randomNonNegativeInt(); + var model = VoyageAIEmbeddingsModelTests.createModel( + randomAlphaOfLength(10), + randomAlphaOfLength(10), + VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + randomNonNegativeInt(), + randomNonNegativeInt(), + randomAlphaOfLength(10), + similarityMeasure + ); + + Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize); + + SimilarityMeasure expectedSimilarityMeasure = similarityMeasure == null ? + VoyageAIService.defaultSimilarity() : similarityMeasure; + assertEquals(expectedSimilarityMeasure, updatedModel.getServiceSettings().similarity()); + assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue()); + } + } + + public void testInfer_Embedding_UnauthorisedResponse() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "detail": "Unauthorized" + } + """; + webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson)); + + var model = VoyageAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + 1024, + 1024, + "model", + null + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + List.of("abc"), + false, + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + MatcherAssert.assertThat(error.getMessage(), containsString("Received an authentication error status code for request")); + MatcherAssert.assertThat(error.getMessage(), containsString("Error message: [Unauthorized]")); + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + } + } + + public void testInfer_Rerank_UnauthorisedResponse() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "detail": "Unauthorized" + } + """; + webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson)); + + var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "model", 1024, false); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + "query", + List.of("candidate1", "candidate2"), + false, + new HashMap<>(), + null, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + MatcherAssert.assertThat(error.getMessage(), containsString("Received an authentication error status code for request")); + MatcherAssert.assertThat(error.getMessage(), containsString("Error message: [Unauthorized]")); + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + } + } + + public void testInfer_Embedding_Get_Response_Ingest() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "model": "voyage-clip-v2", + "object": "list", + "usage": { + "total_tokens": 5, + "prompt_tokens": 5 + }, + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.123, + -0.123 + ] + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = VoyageAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + 1024, + 1024, + "voyage-clip-v2", + null + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + List.of("abc"), + false, + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + assertEquals(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })), result.asMap()); + + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "voyage-clip-v2", "task", "retrieval.passage"))); + } + } + + public void testInfer_Embedding_Get_Response_Search() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "model": "voyage-clip-v2", + "object": "list", + "usage": { + "total_tokens": 5, + "prompt_tokens": 5 + }, + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.123, + -0.123 + ] + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = VoyageAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + 1024, + 1024, + "voyage-clip-v2", + null + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + List.of("abc"), + false, + new HashMap<>(), + InputType.SEARCH, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + assertEquals(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })), result.asMap()); + + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "voyage-clip-v2", "task", "retrieval.query"))); + } + } + + public void testInfer_Embedding_Get_Response_clustering() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + {"model":"voyage-clip-v2","object":"list","usage":{"total_tokens":5,"prompt_tokens":5}, + "data":[{"object":"embedding","index":0,"embedding":[0.123, -0.123]}]} + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = VoyageAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + 1024, + 1024, + "voyage-clip-v2", + null + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + List.of("abc"), + false, + new HashMap<>(), + InputType.CLUSTERING, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + assertEquals(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })), result.asMap()); + + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "voyage-clip-v2", "task", "separation"))); + } + } + + public void testInfer_Embedding_Get_Response_NullInputType() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "model": "voyage-clip-v2", + "object": "list", + "usage": { + "total_tokens": 5, + "prompt_tokens": 5 + }, + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.123, + -0.123 + ] + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = VoyageAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + 1024, + 1024, + "voyage-clip-v2", + null + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer(model, null, List.of("abc"), false, new HashMap<>(), null, InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + assertEquals(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })), result.asMap()); + + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "voyage-clip-v2"))); + } + } + + public void testInfer_Rerank_Get_Response_NoReturnDocuments_NoTopN() throws IOException { + String responseJson = """ + { + "model": "model", + "results": [ + { + "index": 2, + "relevance_score": 0.98005307 + }, + { + "index": 1, + "relevance_score": 0.27904198 + }, + { + "index": 0, + "relevance_score": 0.10194652 + } + ], + "usage": { + "total_tokens": 15 + } + } + """; + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", null, false); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + "query", + List.of("candidate1", "candidate2", "candidate3"), + false, + new HashMap<>(), + null, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + var resultAsMap = result.asMap(); + assertThat( + resultAsMap, + is( + Map.of( + "rerank", + List.of( + Map.of("ranked_doc", Map.of("index", 2, "relevance_score", 0.98005307F)), + Map.of("ranked_doc", Map.of("index", 1, "relevance_score", 0.27904198F)), + Map.of("ranked_doc", Map.of("index", 0, "relevance_score", 0.10194652F)) + ) + ) + ) + ); + + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat( + requestMap, + is( + Map.of( + "query", + "query", + "documents", + List.of("candidate1", "candidate2", "candidate3"), + "model", + "model", + "return_documents", + false + ) + ) + ); + + } + } + + public void testInfer_Rerank_Get_Response_NoReturnDocuments_TopN() throws IOException { + String responseJson = """ + { + "model": "model", + "results": [ + { + "index": 2, + "relevance_score": 0.98005307 + }, + { + "index": 1, + "relevance_score": 0.27904198 + }, + { + "index": 0, + "relevance_score": 0.10194652 + } + ], + "usage": { + "total_tokens": 15 + } + } + """; + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", 3, false); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + "query", + List.of("candidate1", "candidate2", "candidate3", "candidate4"), + false, + new HashMap<>(), + null, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + var resultAsMap = result.asMap(); + assertThat( + resultAsMap, + is( + Map.of( + "rerank", + List.of( + Map.of("ranked_doc", Map.of("index", 2, "relevance_score", 0.98005307F)), + Map.of("ranked_doc", Map.of("index", 1, "relevance_score", 0.27904198F)), + Map.of("ranked_doc", Map.of("index", 0, "relevance_score", 0.10194652F)) + ) + ) + ) + ); + + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat( + requestMap, + is( + Map.of( + "query", + "query", + "documents", + List.of("candidate1", "candidate2", "candidate3", "candidate4"), + "model", + "model", + "return_documents", + false, + "top_n", + 3 + ) + ) + ); + + } + + } + + public void testInfer_Rerank_Get_Response_ReturnDocumentsNull_NoTopN() throws IOException { + String responseJson = """ + { + "model": "model", + "results": [ + { + "index": 2, + "relevance_score": 0.98005307, + "document": { + "text": "candidate3" + } + }, + { + "index": 1, + "relevance_score": 0.27904198, + "document": { + "text": "candidate2" + } + }, + { + "index": 0, + "relevance_score": 0.10194652, + "document": { + "text": "candidate1" + } + } + ], + "usage": { + "total_tokens": 15 + } + } + """; + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", null, null); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + "query", + List.of("candidate1", "candidate2", "candidate3"), + false, + new HashMap<>(), + null, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + var resultAsMap = result.asMap(); + assertThat( + resultAsMap, + is( + Map.of( + "rerank", + List.of( + Map.of("ranked_doc", Map.of("text", "candidate3", "index", 2, "relevance_score", 0.98005307F)), + Map.of("ranked_doc", Map.of("text", "candidate2", "index", 1, "relevance_score", 0.27904198F)), + Map.of("ranked_doc", Map.of("text", "candidate1", "index", 0, "relevance_score", 0.10194652F)) + ) + ) + ) + ); + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat( + requestMap, + is(Map.of("query", "query", "documents", List.of("candidate1", "candidate2", "candidate3"), "model", "model")) + ); + + } + + } + + public void testInfer_Rerank_Get_Response_ReturnDocuments_TopN() throws IOException { + String responseJson = """ + { + "model": "model", + "results": [ + { + "index": 2, + "relevance_score": 0.98005307, + "document": { + "text": "candidate3" + } + }, + { + "index": 1, + "relevance_score": 0.27904198, + "document": { + "text": "candidate2" + } + }, + { + "index": 0, + "relevance_score": 0.10194652, + "document": { + "text": "candidate1" + } + } + ], + "usage": { + "total_tokens": 15 + } + } + """; + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", 3, true); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + "query", + List.of("candidate1", "candidate2", "candidate3", "candidate4"), + false, + new HashMap<>(), + null, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + var resultAsMap = result.asMap(); + assertThat( + resultAsMap, + is( + Map.of( + "rerank", + List.of( + Map.of("ranked_doc", Map.of("text", "candidate3", "index", 2, "relevance_score", 0.98005307F)), + Map.of("ranked_doc", Map.of("text", "candidate2", "index", 1, "relevance_score", 0.27904198F)), + Map.of("ranked_doc", Map.of("text", "candidate1", "index", 0, "relevance_score", 0.10194652F)) + ) + ) + ) + ); + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat( + requestMap, + is( + Map.of( + "query", + "query", + "documents", + List.of("candidate1", "candidate2", "candidate3", "candidate4"), + "model", + "model", + "return_documents", + true, + "top_n", + 3 + ) + ) + ); + + } + + } + + public void testInfer_Embedding_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspecifiedIsPassedInRequest() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "model": "voyage-clip-v2", + "object": "list", + "usage": { + "total_tokens": 5, + "prompt_tokens": 5 + }, + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.123, + -0.123 + ] + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = VoyageAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + new VoyageAIEmbeddingsTaskSettings((InputType) null, null, null), + 1024, + 1024, + "voyage-clip-v2", + null + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + List.of("abc"), + false, + new HashMap<>(), + InputType.UNSPECIFIED, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + MatcherAssert.assertThat(result.asMap(), Matchers.is(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })))); + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "voyage-clip-v2"))); + } + } + + public void test_Embedding_ChunkedInfer_BatchesCallsChunkingSettingsSet() throws IOException { + var model = VoyageAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + new VoyageAIEmbeddingsTaskSettings((InputType) null), + createRandomChunkingSettings(), + 1024, + 1024, + "voyage-clip-v2" + ); + + test_Embedding_ChunkedInfer_BatchesCalls(model); + } + + public void test_Embedding_ChunkedInfer_ChunkingSettingsNotSet() throws IOException { + var model = VoyageAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + new VoyageAIEmbeddingsTaskSettings((InputType) null), + null, + 1024, + 1024, + "voyage-clip-v2" + ); + + test_Embedding_ChunkedInfer_BatchesCalls(model); + } + + private void test_Embedding_ChunkedInfer_BatchesCalls(VoyageAIEmbeddingsModel model) throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + + // Batching will call the service with 2 input + String responseJson = """ + { + "model": "voyage-clip-v2", + "object": "list", + "usage": { + "total_tokens": 5, + "prompt_tokens": 5 + }, + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.123, + -0.123 + ] + }, + { + "object": "embedding", + "index": 1, + "embedding": [ + 0.223, + -0.223 + ] + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + PlainActionFuture> listener = new PlainActionFuture<>(); + // 2 input + service.chunkedInfer( + model, + null, + List.of("foo", "bar"), + new HashMap<>(), + InputType.UNSPECIFIED, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var results = listener.actionGet(TIMEOUT); + assertThat(results, hasSize(2)); + { + assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(0); + assertThat(floatResult.chunks(), hasSize(1)); + assertEquals("foo", floatResult.chunks().get(0).matchedText()); + assertArrayEquals(new float[] { 0.123f, -0.123f }, floatResult.chunks().get(0).embedding(), 0.0f); + } + { + assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(1); + assertThat(floatResult.chunks(), hasSize(1)); + assertEquals("bar", floatResult.chunks().get(0).matchedText()); + assertArrayEquals(new float[] { 0.223f, -0.223f }, floatResult.chunks().get(0).embedding(), 0.0f); + } + + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("foo", "bar"), "model", "voyage-clip-v2"))); + } + } + + public void testDefaultSimilarity() { + assertEquals(SimilarityMeasure.DOT_PRODUCT, VoyageAIService.defaultSimilarity()); + } + + @SuppressWarnings("checkstyle:LineLength") + public void testGetConfiguration() throws Exception { + try (var service = createVoyageAIService()) { + String content = XContentHelper.stripWhitespace( + """ + { + "service": "voyageai", + "name": "Voyage AI", + "task_types": ["text_embedding", "rerank"], + "configurations": { + "api_key": { + "description": "API Key for the provider you're connecting to.", + "label": "API Key", + "required": true, + "sensitive": true, + "updatable": true, + "type": "str", + "supported_task_types": ["text_embedding", "rerank"] + }, + "dimensions": { + "description": "The number of dimensions the resulting embeddings should have. For more information refer to https://api.voyage.ai/redoc#tag/embeddings/operation/create_embedding_v1_embeddings_post.", + "label": "Dimensions", + "required": false, + "sensitive": false, + "updatable": false, + "type": "int", + "supported_task_types": ["text_embedding"] + }, + "model_id": { + "description": "The name of the model to use for the inference task.", + "label": "Model ID", + "required": true, + "sensitive": false, + "updatable": false, + "type": "str", + "supported_task_types": ["text_embedding", "rerank"] + }, + "rate_limit.requests_per_minute": { + "description": "Minimize the number of rate limit errors.", + "label": "Rate Limit", + "required": false, + "sensitive": false, + "updatable": false, + "type": "int", + "supported_task_types": ["text_embedding", "rerank"] + } + } + } + """ + ); + InferenceServiceConfiguration configuration = InferenceServiceConfiguration.fromXContentBytes( + new BytesArray(content), + XContentType.JSON + ); + boolean humanReadable = true; + BytesReference originalBytes = toShuffledXContent(configuration, XContentType.JSON, ToXContent.EMPTY_PARAMS, humanReadable); + InferenceServiceConfiguration serviceConfiguration = service.getConfiguration(); + assertToXContentEquivalent( + originalBytes, + toXContent(serviceConfiguration, XContentType.JSON, humanReadable), + XContentType.JSON + ); + } + } + + public void testDoesNotSupportsStreaming() throws IOException { + try (var service = new VoyageAIService(mock(), createWithEmptySettings(mock()))) { + assertFalse(service.canStream(TaskType.COMPLETION)); + assertFalse(service.canStream(TaskType.ANY)); + } + } + + private Map getRequestConfigMap( + Map serviceSettings, + Map taskSettings, + Map chunkingSettings, + Map secretSettings + ) { + var requestConfigMap = getRequestConfigMap(serviceSettings, taskSettings, secretSettings); + requestConfigMap.put(ModelConfigurations.CHUNKING_SETTINGS, chunkingSettings); + + return requestConfigMap; + } + + private Map getRequestConfigMap( + Map serviceSettings, + Map taskSettings, + Map secretSettings + ) { + var builtServiceSettings = new HashMap<>(); + builtServiceSettings.putAll(serviceSettings); + builtServiceSettings.putAll(secretSettings); + + return new HashMap<>( + Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings) + ); + } + + private Map getRequestConfigMap(Map serviceSettings, Map secretSettings) { + var builtServiceSettings = new HashMap<>(); + builtServiceSettings.putAll(serviceSettings); + builtServiceSettings.putAll(secretSettings); + + return new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings)); + } + + private VoyageAIService createVoyageAIService() { + return new VoyageAIService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsModelTests.java new file mode 100644 index 0000000000000..c11748a122b5a --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsModelTests.java @@ -0,0 +1,171 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai.embeddings; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.hamcrest.MatcherAssert; + +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap; +import static org.hamcrest.Matchers.is; + +public class VoyageAIEmbeddingsModelTests extends ESTestCase { + + public void testOverrideWith_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreEmpty_AndInputTypeIsInvalid() { + var model = createModel("url", "api_key", null, null, "model"); + + var overriddenModel = VoyageAIEmbeddingsModel.of(model, Map.of(), InputType.UNSPECIFIED); + MatcherAssert.assertThat(overriddenModel, is(model)); + } + + public void testOverrideWith_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreNull_AndInputTypeIsInvalid() { + var model = createModel("url", "api_key", null, null, "model"); + + var overriddenModel = VoyageAIEmbeddingsModel.of(model, null, InputType.UNSPECIFIED); + MatcherAssert.assertThat(overriddenModel, is(model)); + } + + public void testOverrideWith_SetsInputTypeToIngest_WhenTheFieldIsNullInModelTaskSettings_AndNullInRequestTaskSettings() { + var model = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings((InputType) null), null, null, "model"); + + var overriddenModel = VoyageAIEmbeddingsModel.of(model, getTaskSettingsMap(null), InputType.INGEST); + var expectedModel = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST), null, null, "model"); + MatcherAssert.assertThat(overriddenModel, is(expectedModel)); + } + + public void testOverrideWith_SetsInputType_FromRequest_IfValid_OverridingStoredTaskSettings() { + var model = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST), null, null, "model"); + + var overriddenModel = VoyageAIEmbeddingsModel.of(model, getTaskSettingsMap(null), InputType.SEARCH); + var expectedModel = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH), null, null, "model"); + MatcherAssert.assertThat(overriddenModel, is(expectedModel)); + } + + public void testOverrideWith_SetsInputType_FromRequest_IfValid_OverridingRequestTaskSettings() { + var model = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings((InputType) null), null, null, "model"); + + var overriddenModel = VoyageAIEmbeddingsModel.of(model, getTaskSettingsMap(InputType.INGEST), InputType.SEARCH); + var expectedModel = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH), null, null, "model"); + MatcherAssert.assertThat(overriddenModel, is(expectedModel)); + } + + public void testOverrideWith_OverridesInputType_WithRequestTaskSettingsSearch_WhenRequestInputTypeIsInvalid() { + var model = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST), null, null, "model"); + + var overriddenModel = VoyageAIEmbeddingsModel.of(model, getTaskSettingsMap(InputType.SEARCH), InputType.UNSPECIFIED); + var expectedModel = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH), null, null, "model"); + MatcherAssert.assertThat(overriddenModel, is(expectedModel)); + } + + public void testOverrideWith_DoesNotSetInputType_FromRequest_IfInputTypeIsInvalid() { + var model = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings((InputType) null), null, null, "model"); + + var overriddenModel = VoyageAIEmbeddingsModel.of(model, getTaskSettingsMap(null), InputType.UNSPECIFIED); + var expectedModel = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings((InputType) null), null, null, "model"); + MatcherAssert.assertThat(overriddenModel, is(expectedModel)); + } + + public void testOverrideWith_DoesNotSetInputType_WhenRequestTaskSettingsIsNull_AndRequestInputTypeIsInvalid() { + var model = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST), null, null, "model"); + + var overriddenModel = VoyageAIEmbeddingsModel.of(model, getTaskSettingsMap(null), InputType.UNSPECIFIED); + var expectedModel = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST), null, null, "model"); + MatcherAssert.assertThat(overriddenModel, is(expectedModel)); + } + + public static VoyageAIEmbeddingsModel createModel(String url, String apiKey, @Nullable Integer tokenLimit, @Nullable String model) { + return createModel(url, apiKey, VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, tokenLimit, null, model); + } + + public static VoyageAIEmbeddingsModel createModel( + String url, + String apiKey, + @Nullable Integer tokenLimit, + @Nullable Integer dimensions, + String model + ) { + return createModel(url, apiKey, VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, tokenLimit, dimensions, model); + } + + public static VoyageAIEmbeddingsModel createModel( + String url, + String apiKey, + VoyageAIEmbeddingsTaskSettings taskSettings, + ChunkingSettings chunkingSettings, + @Nullable Integer tokenLimit, + @Nullable Integer dimensions, + String model + ) { + return new VoyageAIEmbeddingsModel( + "id", + "service", + new VoyageAIEmbeddingsServiceSettings( + new VoyageAIServiceSettings(url, model, null), + SimilarityMeasure.DOT_PRODUCT, + dimensions, + tokenLimit + ), + taskSettings, + chunkingSettings, + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static VoyageAIEmbeddingsModel createModel( + String url, + String apiKey, + VoyageAIEmbeddingsTaskSettings taskSettings, + @Nullable Integer tokenLimit, + @Nullable Integer dimensions, + String model + ) { + return new VoyageAIEmbeddingsModel( + "id", + "service", + new VoyageAIEmbeddingsServiceSettings( + new VoyageAIServiceSettings(url, model, null), + SimilarityMeasure.DOT_PRODUCT, + dimensions, + tokenLimit + ), + taskSettings, + null, + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static VoyageAIEmbeddingsModel createModel( + String url, + String apiKey, + VoyageAIEmbeddingsTaskSettings taskSettings, + @Nullable Integer tokenLimit, + @Nullable Integer dimensions, + String model, + @Nullable SimilarityMeasure similarityMeasure + ) { + return new VoyageAIEmbeddingsModel( + "id", + "service", + new VoyageAIEmbeddingsServiceSettings(new VoyageAIServiceSettings(url, model, null), similarityMeasure, dimensions, tokenLimit), + taskSettings, + null, + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettingsTests.java new file mode 100644 index 0000000000000..62c799fae1557 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettingsTests.java @@ -0,0 +1,188 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai.embeddings; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettingsTests; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.is; + +public class VoyageAIEmbeddingsServiceSettingsTests extends AbstractWireSerializingTestCase { + public static VoyageAIEmbeddingsServiceSettings createRandom() { + SimilarityMeasure similarityMeasure = null; + Integer dims = null; + similarityMeasure = SimilarityMeasure.DOT_PRODUCT; + dims = 1024; + Integer maxInputTokens = randomBoolean() ? null : randomIntBetween(128, 256); + + var commonSettings = VoyageAIServiceSettingsTests.createRandom(); + + return new VoyageAIEmbeddingsServiceSettings(commonSettings, similarityMeasure, dims, maxInputTokens); + } + + public void testFromMap() { + var url = "https://www.abc.com"; + var similarity = SimilarityMeasure.DOT_PRODUCT.toString(); + var dims = 1536; + var maxInputTokens = 512; + var model = "model"; + var serviceSettings = VoyageAIEmbeddingsServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.URL, + url, + ServiceFields.SIMILARITY, + similarity, + ServiceFields.DIMENSIONS, + dims, + ServiceFields.MAX_INPUT_TOKENS, + maxInputTokens, + VoyageAIServiceSettings.MODEL_ID, + model + ) + ), + ConfigurationParseContext.PERSISTENT + ); + + MatcherAssert.assertThat( + serviceSettings, + is( + new VoyageAIEmbeddingsServiceSettings( + new VoyageAIServiceSettings(ServiceUtils.createUri(url), model, null), + SimilarityMeasure.DOT_PRODUCT, + dims, + maxInputTokens + ) + ) + ); + } + + public void testFromMap_WithModelId() { + var url = "https://www.abc.com"; + var similarity = SimilarityMeasure.DOT_PRODUCT.toString(); + var dims = 1536; + var maxInputTokens = 512; + var model = "model"; + var serviceSettings = VoyageAIEmbeddingsServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.URL, + url, + ServiceFields.SIMILARITY, + similarity, + ServiceFields.DIMENSIONS, + dims, + ServiceFields.MAX_INPUT_TOKENS, + maxInputTokens, + VoyageAIServiceSettings.MODEL_ID, + model + ) + ), + ConfigurationParseContext.REQUEST + ); + + MatcherAssert.assertThat( + serviceSettings, + is( + new VoyageAIEmbeddingsServiceSettings( + new VoyageAIServiceSettings(ServiceUtils.createUri(url), model, null), + SimilarityMeasure.DOT_PRODUCT, + dims, + maxInputTokens + ) + ) + ); + } + + public void testFromMap_InvalidSimilarity_ThrowsError() { + var similarity = "by_size"; + var thrownException = expectThrows( + ValidationException.class, + () -> VoyageAIEmbeddingsServiceSettings.fromMap( + new HashMap<>(Map.of(VoyageAIServiceSettings.MODEL_ID, "model", ServiceFields.SIMILARITY, similarity)), + ConfigurationParseContext.PERSISTENT + ) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + is( + "Validation Failed: 1: [service_settings] Invalid value [by_size] received. [similarity] " + + "must be one of [cosine, dot_product, l2_norm];" + ) + ); + } + + public void testToXContent_WritesAllValues() throws IOException { + var serviceSettings = new VoyageAIEmbeddingsServiceSettings( + new VoyageAIServiceSettings("url", "model", new RateLimitSettings(3)), + SimilarityMeasure.COSINE, + 5, + 10 + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + assertThat(xContentResult, is(""" + {"url":"url","model_id":"model",""" + """ + "rate_limit":{"requests_per_minute":3},"similarity":"cosine","dimensions":5,"max_input_tokens":10}""")); + } + + @Override + protected Writeable.Reader instanceReader() { + return VoyageAIEmbeddingsServiceSettings::new; + } + + @Override + protected VoyageAIEmbeddingsServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected VoyageAIEmbeddingsServiceSettings mutateInstance(VoyageAIEmbeddingsServiceSettings instance) throws IOException { + return randomValueOtherThan(instance, VoyageAIEmbeddingsServiceSettingsTests::createRandom); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + List entries = new ArrayList<>(); + entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); + entries.addAll(InferenceNamedWriteablesProvider.getNamedWriteables()); + return new NamedWriteableRegistry(entries); + } + + public static Map getServiceSettingsMap(@Nullable String url, String model) { + var map = new HashMap<>(VoyageAIServiceSettingsTests.getServiceSettingsMap(url, model)); + return map; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsTaskSettingsTests.java new file mode 100644 index 0000000000000..c7ed88cf51114 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsTaskSettingsTests.java @@ -0,0 +1,194 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai.embeddings; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.Locale; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.InputTypeTests.randomWithoutUnspecified; +import static org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings.VALID_REQUEST_VALUES; +import static org.hamcrest.Matchers.is; + +public class VoyageAIEmbeddingsTaskSettingsTests extends AbstractWireSerializingTestCase { + + public static VoyageAIEmbeddingsTaskSettings createRandom() { + var inputType = randomBoolean() ? randomWithoutUnspecified() : null; + + return new VoyageAIEmbeddingsTaskSettings(inputType); + } + + public void testIsEmpty() { + var randomSettings = createRandom(); + var stringRep = Strings.toString(randomSettings); + assertEquals(stringRep, randomSettings.isEmpty(), stringRep.equals("{}")); + } + + public void testUpdatedTaskSettings_NotUpdated_UseInitialSettings() { + var initialSettings = createRandom(); + var newSettings = new VoyageAIEmbeddingsTaskSettings((InputType) null); + Map newSettingsMap = new HashMap<>(); + VoyageAIEmbeddingsTaskSettings updatedSettings = (VoyageAIEmbeddingsTaskSettings) initialSettings.updatedTaskSettings( + Collections.unmodifiableMap(newSettingsMap) + ); + assertEquals(initialSettings.getInputType(), updatedSettings.getInputType()); + } + + public void testUpdatedTaskSettings_Updated_UseNewSettings() { + var initialSettings = createRandom(); + var newSettings = new VoyageAIEmbeddingsTaskSettings(randomWithoutUnspecified()); + Map newSettingsMap = new HashMap<>(); + newSettingsMap.put(VoyageAIEmbeddingsTaskSettings.INPUT_TYPE, newSettings.getInputType().toString()); + VoyageAIEmbeddingsTaskSettings updatedSettings = (VoyageAIEmbeddingsTaskSettings) initialSettings.updatedTaskSettings( + Collections.unmodifiableMap(newSettingsMap) + ); + assertEquals(newSettings.getInputType(), updatedSettings.getInputType()); + } + + public void testFromMap_CreatesEmptySettings_WhenAllFieldsAreNull() { + MatcherAssert.assertThat( + VoyageAIEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of())), + is(new VoyageAIEmbeddingsTaskSettings((InputType) null)) + ); + } + + public void testFromMap_CreatesEmptySettings_WhenMapIsNull() { + MatcherAssert.assertThat(VoyageAIEmbeddingsTaskSettings.fromMap(null), is(new VoyageAIEmbeddingsTaskSettings((InputType) null))); + } + + public void testFromMap_CreatesSettings_WhenAllFieldsOfSettingsArePresent() { + MatcherAssert.assertThat( + VoyageAIEmbeddingsTaskSettings.fromMap( + new HashMap<>(Map.of(VoyageAIEmbeddingsTaskSettings.INPUT_TYPE, InputType.INGEST.toString())) + ), + is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST)) + ); + } + + public void testFromMap_ReturnsFailure_WhenInputTypeIsInvalid() { + var exception = expectThrows( + ValidationException.class, + () -> VoyageAIEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(VoyageAIEmbeddingsTaskSettings.INPUT_TYPE, "abc"))) + ); + + MatcherAssert.assertThat( + exception.getMessage(), + is( + Strings.format( + "Validation Failed: 1: [task_settings] Invalid value [abc] received. [input_type] must be one of [%s];", + getValidValuesSortedAndCombined(VALID_REQUEST_VALUES) + ) + ) + ); + } + + public void testFromMap_ReturnsFailure_WhenInputTypeIsUnspecified() { + var exception = expectThrows( + ValidationException.class, + () -> VoyageAIEmbeddingsTaskSettings.fromMap( + new HashMap<>(Map.of(VoyageAIEmbeddingsTaskSettings.INPUT_TYPE, InputType.UNSPECIFIED.toString())) + ) + ); + + MatcherAssert.assertThat( + exception.getMessage(), + is( + Strings.format( + "Validation Failed: 1: [task_settings] Invalid value [unspecified] received. [input_type] must be one of [%s];", + getValidValuesSortedAndCombined(VALID_REQUEST_VALUES) + ) + ) + ); + } + + private static > String getValidValuesSortedAndCombined(EnumSet validValues) { + var validValuesAsStrings = validValues.stream().map(value -> value.toString().toLowerCase(Locale.ROOT)).toArray(String[]::new); + Arrays.sort(validValuesAsStrings); + + return String.join(", ", validValuesAsStrings); + } + + public void testXContent_ThrowsAssertionFailure_WhenInputTypeIsUnspecified() { + var thrownException = expectThrows(AssertionError.class, () -> new VoyageAIEmbeddingsTaskSettings(InputType.UNSPECIFIED)); + MatcherAssert.assertThat(thrownException.getMessage(), is("received invalid input type value [unspecified]")); + } + + public void testOf_KeepsOriginalValuesWhenRequestSettingsAreNull_AndRequestInputTypeIsInvalid() { + var taskSettings = new VoyageAIEmbeddingsTaskSettings(InputType.INGEST); + var overriddenTaskSettings = VoyageAIEmbeddingsTaskSettings.of( + taskSettings, + VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + InputType.UNSPECIFIED + ); + MatcherAssert.assertThat(overriddenTaskSettings, is(taskSettings)); + } + + public void testOf_UsesRequestTaskSettings() { + var taskSettings = new VoyageAIEmbeddingsTaskSettings((InputType) null); + var overriddenTaskSettings = VoyageAIEmbeddingsTaskSettings.of( + taskSettings, + new VoyageAIEmbeddingsTaskSettings(InputType.INGEST), + InputType.UNSPECIFIED + ); + + MatcherAssert.assertThat(overriddenTaskSettings, is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST))); + } + + public void testOf_UsesRequestTaskSettings_AndRequestInputType() { + var taskSettings = new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH); + var overriddenTaskSettings = VoyageAIEmbeddingsTaskSettings.of( + taskSettings, + new VoyageAIEmbeddingsTaskSettings((InputType) null), + InputType.INGEST + ); + + MatcherAssert.assertThat(overriddenTaskSettings, is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST))); + } + + @Override + protected Writeable.Reader instanceReader() { + return VoyageAIEmbeddingsTaskSettings::new; + } + + @Override + protected VoyageAIEmbeddingsTaskSettings createTestInstance() { + return createRandom(); + } + + @Override + protected VoyageAIEmbeddingsTaskSettings mutateInstance(VoyageAIEmbeddingsTaskSettings instance) throws IOException { + return randomValueOtherThan(instance, VoyageAIEmbeddingsTaskSettingsTests::createRandom); + } + + public static Map getTaskSettingsMapEmpty() { + return new HashMap<>(); + } + + public static Map getTaskSettingsMap(@Nullable InputType inputType) { + var map = new HashMap(); + + if (inputType != null) { + map.put(VoyageAIEmbeddingsTaskSettings.INPUT_TYPE, inputType.toString()); + } + + return map; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModelTests.java new file mode 100644 index 0000000000000..a05d95cdbba42 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModelTests.java @@ -0,0 +1,77 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai.rerank; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings; +import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel; +import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankServiceSettings; +import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankTaskSettings; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +public class VoyageAIRerankModelTests { + + public static VoyageAIRerankModel createModel(String apiKey, String modelId, @Nullable Integer topN) { + return new VoyageAIRerankModel( + "id", + "service", + new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(ESTestCase.randomAlphaOfLength(10), modelId, null)), + new VoyageAIRerankTaskSettings(topN, null), + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static VoyageAIRerankModel createModel(String modelId, @Nullable Integer topN) { + return new VoyageAIRerankModel( + "id", + "service", + new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(ESTestCase.randomAlphaOfLength(10), modelId, null)), + new VoyageAIRerankTaskSettings(topN, null), + new DefaultSecretSettings(ESTestCase.randomSecureStringOfLength(8)) + ); + } + + public static VoyageAIRerankModel createModel(String modelId, @Nullable Integer topN, Boolean returnDocuments) { + return new VoyageAIRerankModel( + "id", + "service", + new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(ESTestCase.randomAlphaOfLength(10), modelId, null)), + new VoyageAIRerankTaskSettings(topN, returnDocuments), + new DefaultSecretSettings(ESTestCase.randomSecureStringOfLength(8)) + ); + } + + public static VoyageAIRerankModel createModel(String url, String modelId, @Nullable Integer topN, Boolean returnDocuments) { + return new VoyageAIRerankModel( + "id", + "service", + new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(url, modelId, null)), + new VoyageAIRerankTaskSettings(topN, returnDocuments), + new DefaultSecretSettings(ESTestCase.randomSecureStringOfLength(8)) + ); + } + + public static VoyageAIRerankModel createModel( + String url, + String apiKey, + String modelId, + @Nullable Integer topN, + Boolean returnDocuments + ) { + return new VoyageAIRerankModel( + "id", + "service", + new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(url, modelId, null)), + new VoyageAIRerankTaskSettings(topN, returnDocuments), + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankServiceSettingsTests.java new file mode 100644 index 0000000000000..64f0f527c8a77 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankServiceSettingsTests.java @@ -0,0 +1,84 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai.rerank; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettingsTests; +import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString; + +public class VoyageAIRerankServiceSettingsTests extends AbstractBWCWireSerializationTestCase { + public static VoyageAIRerankServiceSettings createRandom() { + return new VoyageAIRerankServiceSettings( + new VoyageAIServiceSettings( + randomFrom(new String[] { null, Strings.format("http://%s.com", randomAlphaOfLength(8)) }), + randomAlphaOfLength(10), + RateLimitSettingsTests.createRandom() + ) + ); + } + + public void testToXContent_WritesAllValues() throws IOException { + var url = "http://www.abc.com"; + var model = "model"; + + var serviceSettings = new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(url, model, null)); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "url":"http://www.abc.com", + "model_id":"model", + "rate_limit": { + "requests_per_minute": 2000 + } + } + """)); + } + + @Override + protected Writeable.Reader instanceReader() { + return VoyageAIRerankServiceSettings::new; + } + + @Override + protected VoyageAIRerankServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected VoyageAIRerankServiceSettings mutateInstance(VoyageAIRerankServiceSettings instance) throws IOException { + return randomValueOtherThan(instance, VoyageAIRerankServiceSettingsTests::createRandom); + } + + @Override + protected VoyageAIRerankServiceSettings mutateInstanceForVersion(VoyageAIRerankServiceSettings instance, TransportVersion version) { + return instance; + } + + public static Map getServiceSettingsMap(@Nullable String url, @Nullable String model) { + return new HashMap<>(VoyageAIServiceSettingsTests.getServiceSettingsMap(url, model)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettingsTests.java new file mode 100644 index 0000000000000..b471bdeae898f --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettingsTests.java @@ -0,0 +1,133 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.voyageai.rerank; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankTaskSettings; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.containsString; + +public class VoyageAIRerankTaskSettingsTests extends AbstractWireSerializingTestCase { + + public static VoyageAIRerankTaskSettings createRandom() { + var returnDocuments = randomBoolean() ? randomBoolean() : null; + var topNDocsOnly = randomBoolean() ? randomIntBetween(1, 10) : null; + + return new VoyageAIRerankTaskSettings(topNDocsOnly, returnDocuments); + } + + public void testFromMap_WithValidValues_ReturnsSettings() { + Map taskMap = Map.of(VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, true, VoyageAIRerankTaskSettings.TOP_N_DOCS_ONLY, 5); + var settings = VoyageAIRerankTaskSettings.fromMap(new HashMap<>(taskMap)); + assertTrue(settings.getReturnDocuments()); + assertEquals(5, settings.getTopNDocumentsOnly().intValue()); + } + + public void testFromMap_WithNullValues_ReturnsSettingsWithNulls() { + var settings = VoyageAIRerankTaskSettings.fromMap(Map.of()); + assertNull(settings.getReturnDocuments()); + assertNull(settings.getTopNDocumentsOnly()); + } + + public void testFromMap_WithInvalidReturnDocuments_ThrowsValidationException() { + Map taskMap = Map.of( + VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, + "invalid", + VoyageAIRerankTaskSettings.TOP_N_DOCS_ONLY, + 5 + ); + var thrownException = expectThrows(ValidationException.class, () -> VoyageAIRerankTaskSettings.fromMap(new HashMap<>(taskMap))); + assertThat(thrownException.getMessage(), containsString("field [return_documents] is not of the expected type")); + } + + public void testFromMap_WithInvalidTopNDocsOnly_ThrowsValidationException() { + Map taskMap = Map.of( + VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, + true, + VoyageAIRerankTaskSettings.TOP_N_DOCS_ONLY, + "invalid" + ); + var thrownException = expectThrows(ValidationException.class, () -> VoyageAIRerankTaskSettings.fromMap(new HashMap<>(taskMap))); + assertThat(thrownException.getMessage(), containsString("field [top_n] is not of the expected type")); + } + + public void testUpdatedTaskSettings_WithEmptyMap_ReturnsSameSettings() { + var initialSettings = new VoyageAIRerankTaskSettings(5, true); + VoyageAIRerankTaskSettings updatedSettings = (VoyageAIRerankTaskSettings) initialSettings.updatedTaskSettings(Map.of()); + assertEquals(initialSettings, updatedSettings); + } + + public void testUpdatedTaskSettings_WithNewReturnDocuments_ReturnsUpdatedSettings() { + var initialSettings = new VoyageAIRerankTaskSettings(5, true); + Map newSettings = Map.of(VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, false); + VoyageAIRerankTaskSettings updatedSettings = (VoyageAIRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings); + assertFalse(updatedSettings.getReturnDocuments()); + assertEquals(initialSettings.getTopNDocumentsOnly(), updatedSettings.getTopNDocumentsOnly()); + } + + public void testUpdatedTaskSettings_WithNewTopNDocsOnly_ReturnsUpdatedSettings() { + var initialSettings = new VoyageAIRerankTaskSettings(5, true); + Map newSettings = Map.of(VoyageAIRerankTaskSettings.TOP_N_DOCS_ONLY, 7); + VoyageAIRerankTaskSettings updatedSettings = (VoyageAIRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings); + assertEquals(7, updatedSettings.getTopNDocumentsOnly().intValue()); + assertEquals(initialSettings.getReturnDocuments(), updatedSettings.getReturnDocuments()); + } + + public void testUpdatedTaskSettings_WithMultipleNewValues_ReturnsUpdatedSettings() { + var initialSettings = new VoyageAIRerankTaskSettings(5, true); + Map newSettings = Map.of( + VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, + false, + VoyageAIRerankTaskSettings.TOP_N_DOCS_ONLY, + 7 + ); + VoyageAIRerankTaskSettings updatedSettings = (VoyageAIRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings); + assertFalse(updatedSettings.getReturnDocuments()); + assertEquals(7, updatedSettings.getTopNDocumentsOnly().intValue()); + } + + @Override + protected Writeable.Reader instanceReader() { + return VoyageAIRerankTaskSettings::new; + } + + @Override + protected VoyageAIRerankTaskSettings createTestInstance() { + return createRandom(); + } + + @Override + protected VoyageAIRerankTaskSettings mutateInstance(VoyageAIRerankTaskSettings instance) throws IOException { + return randomValueOtherThan(instance, VoyageAIRerankTaskSettingsTests::createRandom); + } + + public static Map getTaskSettingsMapEmpty() { + return new HashMap<>(); + } + + public static Map getTaskSettingsMap(@Nullable Integer topNDocumentsOnly, Boolean returnDocuments) { + var map = new HashMap(); + + if (topNDocumentsOnly != null) { + map.put(VoyageAIRerankTaskSettings.TOP_N_DOCS_ONLY, topNDocumentsOnly.toString()); + } + + if (returnDocuments != null) { + map.put(VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, returnDocuments.toString()); + } + + return map; + } +} From b94bc5fc6bf20b37cea452448fa595984315b44e Mon Sep 17 00:00:00 2001 From: fzowl Date: Wed, 5 Feb 2025 17:53:28 +0100 Subject: [PATCH 06/20] Correcting the TransportVersions.java --- server/src/main/java/org/elasticsearch/TransportVersions.java | 1 + 1 file changed, 1 insertion(+) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index c9fde8501621a..f822f57f78f3d 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -172,6 +172,7 @@ static TransportVersion def(int id) { public static final TransportVersion TIMEOUT_GET_PARAM_FOR_RESOLVE_CLUSTER = def(8_838_0_00); public static final TransportVersion INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING = def(8_839_0_00); public static final TransportVersion ML_INFERENCE_IBM_WATSONX_RERANK_ADDED = def(8_840_0_00); + public static final TransportVersion COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_X = def(8_840_0_01); public static final TransportVersion ELASTICSEARCH_9_0 = def(9_000_0_00); public static final TransportVersion COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED = def(9_001_0_00); public static final TransportVersion REMOVE_SNAPSHOT_FAILURES = def(9_002_0_00); From be1e9cf273b68c847aea77389bbb72467f0434ce Mon Sep 17 00:00:00 2001 From: fzowl Date: Wed, 5 Feb 2025 18:18:20 +0100 Subject: [PATCH 07/20] Correcting due to comments --- .../VoyageAIEmbeddingsRequestEntity.java | 5 +++-- .../embeddings/VoyageAIEmbeddingType.java | 16 ---------------- 2 files changed, 3 insertions(+), 18 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntity.java index ebd68bef359a7..9021148b5acd4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntity.java @@ -48,8 +48,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(INPUT_FIELD, input); builder.field(MODEL_FIELD, model); - if (taskSettings.getInputType() != null) { - builder.field(INPUT_TYPE_FIELD, convertToString(taskSettings.getInputType())); + var inputType = convertToString(taskSettings.getInputType()); + if (inputType != null) { + builder.field(INPUT_TYPE_FIELD, inputType); } if(taskSettings.getTruncation() != null) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingType.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingType.java index 432d8797ccfce..43be526e7e7b6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingType.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingType.java @@ -102,20 +102,4 @@ public static VoyageAIEmbeddingType fromElementType(DenseVectorFieldMapper.Eleme public DenseVectorFieldMapper.ElementType toElementType() { return elementType; } - - /** - * Returns an embedding type that is known based on the transport version provided. If the embedding type enum was not yet - * introduced it will be defaulted INT8. - * - * @param embeddingType the value to translate if necessary - * @param version the version that dictates the translation - * @return the embedding type that is known to the version passed in - */ - public static VoyageAIEmbeddingType translateToVersion(VoyageAIEmbeddingType embeddingType, TransportVersion version) { - if (version.before(TransportVersions.V_8_14_0) && embeddingType == BYTE) { - return INT8; - } - - return embeddingType; - } } From 71dfdc8d85162d779224e4b7e5857aa204781add Mon Sep 17 00:00:00 2001 From: fzowl Date: Wed, 5 Feb 2025 18:40:46 +0100 Subject: [PATCH 08/20] Adding BIT support --- .../VoyageAIEmbeddingsResponseEntity.java | 22 ++++++++++++++----- .../embeddings/VoyageAIEmbeddingType.java | 17 ++++++++++---- 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntity.java index 38ee83d2956ba..58c3809f8975a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntity.java @@ -16,6 +16,8 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding; +import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingBitResults; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; @@ -43,7 +45,9 @@ public class VoyageAIEmbeddingsResponseEntity { toLowerCase(VoyageAIEmbeddingType.FLOAT), VoyageAIEmbeddingsResponseEntity::parseFloatEmbeddingsArray, toLowerCase(VoyageAIEmbeddingType.INT8), - VoyageAIEmbeddingsResponseEntity::parseByteEmbeddingsArray + VoyageAIEmbeddingsResponseEntity::parseByteEmbeddingsArray, + toLowerCase(VoyageAIEmbeddingType.BINARY), + VoyageAIEmbeddingsResponseEntity::parseBitEmbeddingsArray ); private static final String VALID_EMBEDDING_TYPES_STRING = supportedEmbeddingTypes(); @@ -119,7 +123,7 @@ public static InferenceServiceResults fromResponse(Request request, HttpResult r return new InferenceTextEmbeddingFloatResults(embeddingList); } else if(embeddingType == VoyageAIEmbeddingType.INT8) { - List embeddingList = parseList( + List embeddingList = parseList( jsonParser, VoyageAIEmbeddingsResponseEntity::parseEmbeddingObjectByte ); @@ -144,7 +148,7 @@ private static InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding parseE return InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding.of(embeddingValuesList); } - private static InferenceTextEmbeddingByteResults.InferenceByteEmbedding parseEmbeddingObjectByte(XContentParser parser) + private static InferenceByteEmbedding parseEmbeddingObjectByte(XContentParser parser) throws IOException { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); @@ -154,7 +158,13 @@ private static InferenceTextEmbeddingByteResults.InferenceByteEmbedding parseEmb // parse and discard the rest of the object consumeUntilObjectEnd(parser); - return InferenceTextEmbeddingByteResults.InferenceByteEmbedding.of(embeddingValuesList); + return InferenceByteEmbedding.of(embeddingValuesList); + } + + private static InferenceServiceResults parseBitEmbeddingsArray(XContentParser parser) throws IOException { + var embeddingList = parseList(parser, VoyageAIEmbeddingsResponseEntity::parseByteArrayEntry); + + return new InferenceTextEmbeddingBitResults(embeddingList); } private static InferenceServiceResults parseByteEmbeddingsArray(XContentParser parser) throws IOException { @@ -163,11 +173,11 @@ private static InferenceServiceResults parseByteEmbeddingsArray(XContentParser p return new InferenceTextEmbeddingByteResults(embeddingList); } - private static InferenceTextEmbeddingByteResults.InferenceByteEmbedding parseByteArrayEntry(XContentParser parser) throws IOException { + private static InferenceByteEmbedding parseByteArrayEntry(XContentParser parser) throws IOException { ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); List embeddingValuesList = parseList(parser, VoyageAIEmbeddingsResponseEntity::parseEmbeddingInt8Entry); - return InferenceTextEmbeddingByteResults.InferenceByteEmbedding.of(embeddingValuesList); + return InferenceByteEmbedding.of(embeddingValuesList); } private static Byte parseEmbeddingInt8Entry(XContentParser parser) throws IOException { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingType.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingType.java index 43be526e7e7b6..efa72ff2e7980 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingType.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingType.java @@ -7,8 +7,6 @@ package org.elasticsearch.xpack.inference.services.voyageai.embeddings; -import org.elasticsearch.TransportVersion; -import org.elasticsearch.TransportVersions; import org.elasticsearch.common.Strings; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; @@ -36,18 +34,29 @@ public enum VoyageAIEmbeddingType { /** * This is a synonym for INT8 */ - BYTE(DenseVectorFieldMapper.ElementType.BYTE, RequestConstants.INT8); + BYTE(DenseVectorFieldMapper.ElementType.BYTE, RequestConstants.INT8), + /** + * Use this when you want to get back binary embeddings. Valid only for v3 models. + */ + BIT(DenseVectorFieldMapper.ElementType.BIT, RequestConstants.BIT), + /** + * This is a synonym for BIT + */ + BINARY(DenseVectorFieldMapper.ElementType.BIT, RequestConstants.BIT); private static final class RequestConstants { private static final String FLOAT = "float"; private static final String INT8 = "int8"; + private static final String BIT = "binary"; } private static final Map ELEMENT_TYPE_TO_VOYAGE_EMBEDDING = Map.of( DenseVectorFieldMapper.ElementType.FLOAT, FLOAT, DenseVectorFieldMapper.ElementType.BYTE, - BYTE + BYTE, + DenseVectorFieldMapper.ElementType.BIT, + BIT ); static final EnumSet SUPPORTED_ELEMENT_TYPES = EnumSet.copyOf( ELEMENT_TYPE_TO_VOYAGE_EMBEDDING.keySet() From 8f6e03b33ec16bf6297596b1f1b564b79b037cc6 Mon Sep 17 00:00:00 2001 From: fzowl Date: Fri, 7 Feb 2025 18:17:40 +0100 Subject: [PATCH 09/20] Initial tests --- .../voyageai/VoyageAIEmbeddingsRequest.java | 16 +- .../VoyageAIEmbeddingsRequestEntity.java | 9 +- .../voyageai/VoyageAIRerankRequestEntity.java | 2 +- .../VoyageAIEmbeddingsResponseEntity.java | 83 ++++----- .../embeddings/VoyageAIEmbeddingType.java | 6 +- .../VoyageAIEmbeddingsServiceSettings.java | 3 +- .../VoyageAIEmbeddingsTaskSettings.java | 32 +--- .../rerank/VoyageAIRerankTaskSettings.java | 12 +- .../VoyageAIServiceSettingsTests.java | 1 - .../voyageai/VoyageAIServiceTests.java | 176 ++++++++++-------- .../VoyageAIEmbeddingsModelTests.java | 39 ++-- ...oyageAIEmbeddingsServiceSettingsTests.java | 11 +- .../VoyageAIEmbeddingsTaskSettingsTests.java | 62 ++++-- .../rerank/VoyageAIRerankModelTests.java | 28 +-- .../VoyageAIRerankServiceSettingsTests.java | 3 +- .../VoyageAIRerankTaskSettingsTests.java | 44 +++-- 16 files changed, 276 insertions(+), 251 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequest.java index 7512bd723d5ad..ee01e25eafb33 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequest.java @@ -49,12 +49,8 @@ public HttpRequest createHttpRequest() { HttpPost httpPost = new HttpPost(account.uri()); ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString(new VoyageAIEmbeddingsRequestEntity( - input, - serviceSettings, - taskSettings, - model - )).getBytes(StandardCharsets.UTF_8) + Strings.toString(new VoyageAIEmbeddingsRequestEntity(input, serviceSettings, taskSettings, model)) + .getBytes(StandardCharsets.UTF_8) ); httpPost.setEntity(byteEntity); @@ -83,9 +79,13 @@ public boolean[] getTruncationInfo() { return null; } - public VoyageAIEmbeddingsTaskSettings getTaskSettings() { return taskSettings; } + public VoyageAIEmbeddingsTaskSettings getTaskSettings() { + return taskSettings; + } - public VoyageAIEmbeddingsServiceSettings getServiceSettings() { return serviceSettings; } + public VoyageAIEmbeddingsServiceSettings getServiceSettings() { + return serviceSettings; + } public static URI buildDefaultUri() throws URISyntaxException { return new URIBuilder().setScheme("https") diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntity.java index 9021148b5acd4..8191443edf75c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntity.java @@ -53,16 +53,16 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(INPUT_TYPE_FIELD, inputType); } - if(taskSettings.getTruncation() != null) { + if (taskSettings.getTruncation() != null) { builder.field(TRUNCATION_FIELD, taskSettings.getTruncation()); } - if(serviceSettings.dimensions() != null) { + if (serviceSettings.dimensions() != null) { builder.field(OUTPUT_DIMENSION, serviceSettings.dimensions()); } - if(serviceSettings.getEmbeddingType() != null) { - builder.field(OUTPUT_DTYPE_FIELD, serviceSettings.getEmbeddingType()); + if (serviceSettings.getEmbeddingType() != null) { + builder.field(OUTPUT_DTYPE_FIELD, serviceSettings.getEmbeddingType().toRequestString()); } builder.endObject(); @@ -71,6 +71,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws static String convertToString(InputType inputType) { return switch (inputType) { + case null -> null; case INGEST -> DOCUMENT; case SEARCH -> QUERY; default -> { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntity.java index eb6c7898e9e72..0f7baaa35044e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntity.java @@ -52,7 +52,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY, taskSettings.getTopKDocumentsOnly()); } - if(taskSettings.getTruncation() != null) { + if (taskSettings.getTruncation() != null) { builder.field(TRUNCATION_FIELD, taskSettings.getTruncation()); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntity.java index 58c3809f8975a..218ef932c9bf8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntity.java @@ -10,7 +10,6 @@ package org.elasticsearch.xpack.inference.external.response.voyageai; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; -import org.elasticsearch.core.CheckedFunction; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentParser; @@ -26,34 +25,27 @@ import org.elasticsearch.xpack.inference.external.response.XContentUtils; import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingType; -import static org.elasticsearch.xpack.inference.external.response.XContentUtils.consumeUntilObjectEnd; -import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; -import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; - import java.io.IOException; import java.util.Arrays; import java.util.List; -import java.util.Map; import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.consumeUntilObjectEnd; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; import static org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingType.toLowerCase; public class VoyageAIEmbeddingsResponseEntity { private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in VoyageAI embeddings response"; - private static final Map> EMBEDDING_PARSERS = Map.of( - toLowerCase(VoyageAIEmbeddingType.FLOAT), - VoyageAIEmbeddingsResponseEntity::parseFloatEmbeddingsArray, - toLowerCase(VoyageAIEmbeddingType.INT8), - VoyageAIEmbeddingsResponseEntity::parseByteEmbeddingsArray, - toLowerCase(VoyageAIEmbeddingType.BINARY), - VoyageAIEmbeddingsResponseEntity::parseBitEmbeddingsArray - ); private static final String VALID_EMBEDDING_TYPES_STRING = supportedEmbeddingTypes(); private static String supportedEmbeddingTypes() { - var validTypes = EMBEDDING_PARSERS.keySet().toArray(String[]::new); + String[] validTypes = new String[] { + toLowerCase(VoyageAIEmbeddingType.FLOAT), + toLowerCase(VoyageAIEmbeddingType.INT8), + toLowerCase(VoyageAIEmbeddingType.BIT) }; Arrays.sort(validTypes); return String.join(", ", validTypes); } @@ -105,7 +97,7 @@ private static String supportedEmbeddingTypes() { */ public static InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException { var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); - VoyageAIEmbeddingType embeddingType = ((VoyageAIEmbeddingsRequest)request).getServiceSettings().getEmbeddingType(); + VoyageAIEmbeddingType embeddingType = ((VoyageAIEmbeddingsRequest) request).getServiceSettings().getEmbeddingType(); try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { moveToFirstToken(jsonParser); @@ -115,22 +107,31 @@ public static InferenceServiceResults fromResponse(Request request, HttpResult r positionParserAtTokenAfterField(jsonParser, "data", FAILED_TO_FIND_FIELD_TEMPLATE); - if(embeddingType == null || embeddingType == VoyageAIEmbeddingType.FLOAT) { + if (embeddingType == null || embeddingType == VoyageAIEmbeddingType.FLOAT) { List embeddingList = parseList( jsonParser, VoyageAIEmbeddingsResponseEntity::parseEmbeddingObjectFloat ); return new InferenceTextEmbeddingFloatResults(embeddingList); - } else if(embeddingType == VoyageAIEmbeddingType.INT8) { + } else if (embeddingType == VoyageAIEmbeddingType.INT8) { List embeddingList = parseList( jsonParser, VoyageAIEmbeddingsResponseEntity::parseEmbeddingObjectByte ); return new InferenceTextEmbeddingByteResults(embeddingList); + } else if (embeddingType == VoyageAIEmbeddingType.BIT || embeddingType == VoyageAIEmbeddingType.BINARY) { + List embeddingList = parseList( + jsonParser, + VoyageAIEmbeddingsResponseEntity::parseEmbeddingObjectBit + ); + + return new InferenceTextEmbeddingBitResults(embeddingList); } else { - throw new IllegalArgumentException("Illegal output_dtype value: " + embeddingType); + throw new IllegalArgumentException( + "Illegal embedding_type value: " + embeddingType + ". Supported types are: " + VALID_EMBEDDING_TYPES_STRING + ); } } } @@ -148,8 +149,7 @@ private static InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding parseE return InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding.of(embeddingValuesList); } - private static InferenceByteEmbedding parseEmbeddingObjectByte(XContentParser parser) - throws IOException { + private static InferenceByteEmbedding parseEmbeddingObjectByte(XContentParser parser) throws IOException { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE); @@ -161,21 +161,14 @@ private static InferenceByteEmbedding parseEmbeddingObjectByte(XContentParser pa return InferenceByteEmbedding.of(embeddingValuesList); } - private static InferenceServiceResults parseBitEmbeddingsArray(XContentParser parser) throws IOException { - var embeddingList = parseList(parser, VoyageAIEmbeddingsResponseEntity::parseByteArrayEntry); - - return new InferenceTextEmbeddingBitResults(embeddingList); - } - - private static InferenceServiceResults parseByteEmbeddingsArray(XContentParser parser) throws IOException { - var embeddingList = parseList(parser, VoyageAIEmbeddingsResponseEntity::parseByteArrayEntry); + private static InferenceByteEmbedding parseEmbeddingObjectBit(XContentParser parser) throws IOException { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - return new InferenceTextEmbeddingByteResults(embeddingList); - } + positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE); - private static InferenceByteEmbedding parseByteArrayEntry(XContentParser parser) throws IOException { - ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); - List embeddingValuesList = parseList(parser, VoyageAIEmbeddingsResponseEntity::parseEmbeddingInt8Entry); + List embeddingValuesList = parseList(parser, VoyageAIEmbeddingsResponseEntity::parseEmbeddingBitEntry); + // parse and discard the rest of the object + consumeUntilObjectEnd(parser); return InferenceByteEmbedding.of(embeddingValuesList); } @@ -189,24 +182,20 @@ private static Byte parseEmbeddingInt8Entry(XContentParser parser) throws IOExce return (byte) parsedByte; } + private static Byte parseEmbeddingBitEntry(XContentParser parser) throws IOException { + XContentParser.Token token = parser.currentToken(); + ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, token, parser); + var parsedBit = parser.shortValue(); + checkByteBounds(parsedBit); + + return (byte) parsedBit; + } + private static void checkByteBounds(short value) { if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) { throw new IllegalArgumentException("Value [" + value + "] is out of range for a byte"); } } - private static InferenceServiceResults parseFloatEmbeddingsArray(XContentParser parser) throws IOException { - var embeddingList = parseList(parser, VoyageAIEmbeddingsResponseEntity::parseFloatArrayEntry); - - return new InferenceTextEmbeddingFloatResults(embeddingList); - } - - private static InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding parseFloatArrayEntry(XContentParser parser) - throws IOException { - ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); - List embeddingValuesList = parseList(parser, XContentUtils::parseFloat); - return InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding.of(embeddingValuesList); - } - private VoyageAIEmbeddingsResponseEntity() {} } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingType.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingType.java index efa72ff2e7980..db13e46b14641 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingType.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingType.java @@ -38,16 +38,16 @@ public enum VoyageAIEmbeddingType { /** * Use this when you want to get back binary embeddings. Valid only for v3 models. */ - BIT(DenseVectorFieldMapper.ElementType.BIT, RequestConstants.BIT), + BIT(DenseVectorFieldMapper.ElementType.BIT, RequestConstants.BINARY), /** * This is a synonym for BIT */ - BINARY(DenseVectorFieldMapper.ElementType.BIT, RequestConstants.BIT); + BINARY(DenseVectorFieldMapper.ElementType.BIT, RequestConstants.BINARY); private static final class RequestConstants { private static final String FLOAT = "float"; private static final String INT8 = "int8"; - private static final String BIT = "binary"; + private static final String BINARY = "binary"; } private static final Map ELEMENT_TYPE_TO_VOYAGE_EMBEDDING = Map.of( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettings.java index f3b5b0a77f13a..909fd4bf99514 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettings.java @@ -132,10 +132,10 @@ public VoyageAIEmbeddingsServiceSettings( public VoyageAIEmbeddingsServiceSettings(StreamInput in) throws IOException { this.commonSettings = new VoyageAIServiceSettings(in); - this.embeddingType = Objects.requireNonNullElse(in.readOptionalEnum(VoyageAIEmbeddingType.class), VoyageAIEmbeddingType.FLOAT); this.similarity = in.readOptionalEnum(SimilarityMeasure.class); this.dimensions = in.readOptionalVInt(); this.maxInputTokens = in.readOptionalVInt(); + this.embeddingType = Objects.requireNonNullElse(in.readOptionalEnum(VoyageAIEmbeddingType.class), VoyageAIEmbeddingType.FLOAT); } public VoyageAIServiceSettings getCommonSettings() { @@ -165,7 +165,6 @@ public VoyageAIEmbeddingType getEmbeddingType() { return embeddingType; } - @Override public DenseVectorFieldMapper.ElementType elementType() { return embeddingType == null ? DenseVectorFieldMapper.ElementType.FLOAT : embeddingType.toElementType(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsTaskSettings.java index b89c414e26520..d9e6076d53bcb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsTaskSettings.java @@ -27,8 +27,6 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; -import static org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceFields.OUTPUT_DIMENSION; import static org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceFields.TRUNCATION; /** @@ -43,10 +41,7 @@ public class VoyageAIEmbeddingsTaskSettings implements TaskSettings { public static final String NAME = "voyageai_embeddings_task_settings"; public static final VoyageAIEmbeddingsTaskSettings EMPTY_SETTINGS = new VoyageAIEmbeddingsTaskSettings(null, null); static final String INPUT_TYPE = "input_type"; - static final EnumSet VALID_REQUEST_VALUES = EnumSet.of( - InputType.INGEST, - InputType.SEARCH - ); + static final EnumSet VALID_REQUEST_VALUES = EnumSet.of(InputType.INGEST, InputType.SEARCH); public static VoyageAIEmbeddingsTaskSettings fromMap(Map map) { if (map == null || map.isEmpty()) { @@ -63,17 +58,7 @@ public static VoyageAIEmbeddingsTaskSettings fromMap(Map map) { VALID_REQUEST_VALUES, validationException ); - Boolean truncation = extractOptionalBoolean( - map, - TRUNCATION, - validationException - ); - Integer outputDimension = extractOptionalPositiveInteger( - map, - OUTPUT_DIMENSION, - ModelConfigurations.TASK_SETTINGS, - validationException - ); + Boolean truncation = extractOptionalBoolean(map, TRUNCATION, validationException); if (validationException.validationErrors().isEmpty() == false) { throw validationException; @@ -132,16 +117,10 @@ private static Boolean getValidTruncation( private final Boolean truncation; public VoyageAIEmbeddingsTaskSettings(StreamInput in) throws IOException { - this( - in.readOptionalEnum(InputType.class), - in.readOptionalBoolean() - ); + this(in.readOptionalEnum(InputType.class), in.readOptionalBoolean()); } - public VoyageAIEmbeddingsTaskSettings( - @Nullable InputType inputType, - @Nullable Boolean truncation - ) { + public VoyageAIEmbeddingsTaskSettings(@Nullable InputType inputType, @Nullable Boolean truncation) { validateInputType(inputType); this.inputType = inputType; this.truncation = truncation; @@ -204,8 +183,7 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; VoyageAIEmbeddingsTaskSettings that = (VoyageAIEmbeddingsTaskSettings) o; - return Objects.equals(inputType, that.inputType) && - Objects.equals(truncation, that.truncation); + return Objects.equals(inputType, that.inputType) && Objects.equals(truncation, that.truncation); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettings.java index 0aa65abfda3d2..1fc9505c13743 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettings.java @@ -55,11 +55,7 @@ public static VoyageAIRerankTaskSettings fromMap(Map map) { validationException ); - Boolean truncation = extractOptionalBoolean( - map, - TRUNCATION, - validationException - ); + Boolean truncation = extractOptionalBoolean(map, TRUNCATION, validationException); if (validationException.validationErrors().isEmpty() == false) { throw validationException; @@ -156,9 +152,9 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; VoyageAIRerankTaskSettings that = (VoyageAIRerankTaskSettings) o; - return Objects.equals(topKDocumentsOnly, that.topKDocumentsOnly) && - Objects.equals(returnDocuments, that.returnDocuments) && - Objects.equals(truncation, that.truncation); + return Objects.equals(topKDocumentsOnly, that.topKDocumentsOnly) + && Objects.equals(returnDocuments, that.returnDocuments) + && Objects.equals(truncation, that.truncation); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceSettingsTests.java index 1fad1c6554da6..2ee0c45e2d93b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceSettingsTests.java @@ -18,7 +18,6 @@ import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.ServiceUtils; -import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; import org.hamcrest.MatcherAssert; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java index 1907599f021ab..fee900220f00b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java @@ -39,7 +39,6 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; -import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIService; import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel; import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModelTests; import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettingsTests; @@ -110,7 +109,7 @@ public void testParseRequestConfig_CreatesAVoyageAIEmbeddingsModel() throws IOEx var embeddingsModel = (VoyageAIEmbeddingsModel) model; MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null, null))); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null))); MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); }, e -> fail("Model parsing should have succeeded " + e.getMessage())); @@ -136,7 +135,7 @@ public void testParseRequestConfig_CreatesAVoyageAIEmbeddingsModelWhenChunkingSe var embeddingsModel = (VoyageAIEmbeddingsModel) model; MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST))); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null))); MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); @@ -165,7 +164,7 @@ public void testParseRequestConfig_CreatesAVoyageAIEmbeddingsModelWhenChunkingSe var embeddingsModel = (VoyageAIEmbeddingsModel) model; MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST))); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null))); MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); @@ -358,7 +357,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAVoyageAIEmbeddingsModel( var embeddingsModel = (VoyageAIEmbeddingsModel) model; MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null))); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null, null))); MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); } } @@ -384,7 +383,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAVoyageAIEmbeddingsModelW var embeddingsModel = (VoyageAIEmbeddingsModel) model; MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null))); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null, null))); MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); } @@ -410,7 +409,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAVoyageAIEmbeddingsModelW var embeddingsModel = (VoyageAIEmbeddingsModel) model; MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null))); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null, null))); MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); } @@ -460,7 +459,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAVoyageAIEmbeddingsModelW var embeddingsModel = (VoyageAIEmbeddingsModel) model; assertNull(embeddingsModel.getServiceSettings().getCommonSettings().uri()); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST))); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null))); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); } @@ -487,7 +486,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists var embeddingsModel = (VoyageAIEmbeddingsModel) model; MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH))); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH, null))); MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); } } @@ -541,7 +540,7 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSe var embeddingsModel = (VoyageAIEmbeddingsModel) model; MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null))); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null, null))); MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); } } @@ -597,7 +596,7 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTa var embeddingsModel = (VoyageAIEmbeddingsModel) model; MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH))); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH, null))); MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); } } @@ -616,7 +615,7 @@ public void testParsePersistedConfig_CreatesAVoyageAIEmbeddingsModel() throws IO var embeddingsModel = (VoyageAIEmbeddingsModel) model; MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null))); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null, null))); assertNull(embeddingsModel.getSecretSettings()); } } @@ -636,7 +635,7 @@ public void testParsePersistedConfig_CreatesAVoyageAIEmbeddingsModelWhenChunking var embeddingsModel = (VoyageAIEmbeddingsModel) model; MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null))); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null, null))); MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); assertNull(embeddingsModel.getSecretSettings()); } @@ -656,7 +655,7 @@ public void testParsePersistedConfig_CreatesAVoyageAIEmbeddingsModelWhenChunking var embeddingsModel = (VoyageAIEmbeddingsModel) model; MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null))); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null, null))); MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); assertNull(embeddingsModel.getSecretSettings()); } @@ -695,7 +694,7 @@ public void testParsePersistedConfig_CreatesAVoyageAIEmbeddingsModelWithoutUrl() var embeddingsModel = (VoyageAIEmbeddingsModel) model; assertNull(embeddingsModel.getServiceSettings().getCommonSettings().uri()); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null))); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null, null))); assertNull(embeddingsModel.getSecretSettings()); } } @@ -737,7 +736,7 @@ public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInServiceSettin var embeddingsModel = (VoyageAIEmbeddingsModel) model; MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH))); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH, null))); assertNull(embeddingsModel.getSecretSettings()); } } @@ -758,7 +757,7 @@ public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings( var embeddingsModel = (VoyageAIEmbeddingsModel) model; MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST))); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null))); assertNull(embeddingsModel.getSecretSettings()); } } @@ -996,8 +995,9 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize); - SimilarityMeasure expectedSimilarityMeasure = similarityMeasure == null ? - VoyageAIService.defaultSimilarity() : similarityMeasure; + SimilarityMeasure expectedSimilarityMeasure = similarityMeasure == null + ? VoyageAIService.defaultSimilarity() + : similarityMeasure; assertEquals(expectedSimilarityMeasure, updatedModel.getServiceSettings().similarity()); assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue()); } @@ -1055,7 +1055,7 @@ public void testInfer_Rerank_UnauthorisedResponse() throws IOException { """; webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson)); - var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "model", 1024, false); + var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "model", 1024, false, false); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( model, @@ -1136,7 +1136,16 @@ public void testInfer_Embedding_Get_Response_Ingest() throws IOException { MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "voyage-clip-v2", "task", "retrieval.passage"))); + MatcherAssert.assertThat( + requestMap, + is(Map.of( + "input", List.of("abc"), + "model", "voyage-clip-v2", + "input_type", "document", + "output_dtype", "float", + "output_dimension", 1024 + )) + ); } } @@ -1201,7 +1210,13 @@ public void testInfer_Embedding_Get_Response_Search() throws IOException { MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "voyage-clip-v2", "task", "retrieval.query"))); + MatcherAssert.assertThat(requestMap, is(Map.of( + "input", List.of("abc"), + "model", "voyage-clip-v2", + "input_type", "query", + "output_dtype", "float", + "output_dimension", 1024 + ))); } } @@ -1250,7 +1265,12 @@ public void testInfer_Embedding_Get_Response_clustering() throws IOException { MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "voyage-clip-v2", "task", "separation"))); + MatcherAssert.assertThat(requestMap, is(Map.of( + "input", List.of("abc"), + "model", "voyage-clip-v2", + "output_dtype", "float", + "output_dimension", 1024 + ))); } } @@ -1306,7 +1326,12 @@ public void testInfer_Embedding_Get_Response_NullInputType() throws IOException MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "voyage-clip-v2"))); + MatcherAssert.assertThat(requestMap, is(Map.of( + "input", List.of("abc"), + "model", "voyage-clip-v2", + "output_dtype", "float", + "output_dimension", 1024 + ))); } } @@ -1314,7 +1339,8 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_NoTopN() throws IOEx String responseJson = """ { "model": "model", - "results": [ + "object": "list", + "data": [ { "index": 2, "relevance_score": 0.98005307 @@ -1337,7 +1363,7 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_NoTopN() throws IOEx try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", null, false); + var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", null, false, false); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( model, @@ -1385,6 +1411,8 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_NoTopN() throws IOEx "model", "model", "return_documents", + false, + "truncation", false ) ) @@ -1396,8 +1424,9 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_NoTopN() throws IOEx public void testInfer_Rerank_Get_Response_NoReturnDocuments_TopN() throws IOException { String responseJson = """ { + "object": "list", "model": "model", - "results": [ + "data": [ { "index": 2, "relevance_score": 0.98005307 @@ -1420,7 +1449,7 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_TopN() throws IOExce try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", 3, false); + var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", 3, false, false); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( model, @@ -1469,8 +1498,10 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_TopN() throws IOExce "model", "return_documents", false, - "top_n", - 3 + "top_k", + 3, + "truncation", + false ) ) ); @@ -1482,28 +1513,23 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_TopN() throws IOExce public void testInfer_Rerank_Get_Response_ReturnDocumentsNull_NoTopN() throws IOException { String responseJson = """ { + "object": "list", "model": "model", - "results": [ + "data": [ { "index": 2, "relevance_score": 0.98005307, - "document": { - "text": "candidate3" - } + "document": "candidate3" }, { "index": 1, "relevance_score": 0.27904198, - "document": { - "text": "candidate2" - } + "document": "candidate2" }, { "index": 0, "relevance_score": 0.10194652, - "document": { - "text": "candidate1" - } + "document": "candidate1" } ], "usage": { @@ -1515,7 +1541,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocumentsNull_NoTopN() throws IO try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", null, null); + var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", null, null, null); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( model, @@ -1553,7 +1579,11 @@ public void testInfer_Rerank_Get_Response_ReturnDocumentsNull_NoTopN() throws IO var requestMap = entityAsMap(webServer.requests().get(0).getBody()); MatcherAssert.assertThat( requestMap, - is(Map.of("query", "query", "documents", List.of("candidate1", "candidate2", "candidate3"), "model", "model")) + is(Map.of( + "query", "query", + "documents", List.of("candidate1", "candidate2", "candidate3"), + "model", "model" + )) ); } @@ -1563,28 +1593,23 @@ public void testInfer_Rerank_Get_Response_ReturnDocumentsNull_NoTopN() throws IO public void testInfer_Rerank_Get_Response_ReturnDocuments_TopN() throws IOException { String responseJson = """ { + "object": "list", "model": "model", - "results": [ + "data": [ { "index": 2, "relevance_score": 0.98005307, - "document": { - "text": "candidate3" - } + "document": "candidate3" }, { "index": 1, "relevance_score": 0.27904198, - "document": { - "text": "candidate2" - } + "document": "candidate2" }, { "index": 0, "relevance_score": 0.10194652, - "document": { - "text": "candidate1" - } + "document": "candidate1" } ], "usage": { @@ -1596,7 +1621,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocuments_TopN() throws IOExcept try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", 3, true); + var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", 3, true, true); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( model, @@ -1644,8 +1669,10 @@ public void testInfer_Rerank_Get_Response_ReturnDocuments_TopN() throws IOExcept "model", "return_documents", true, - "top_n", - 3 + "top_k", + 3, + "truncation", + true ) ) ); @@ -1684,7 +1711,7 @@ public void testInfer_Embedding_DoesNotSetInputType_WhenNotPresentInTaskSettings var model = VoyageAIEmbeddingsModelTests.createModel( getUrl(webServer), "secret", - new VoyageAIEmbeddingsTaskSettings((InputType) null, null, null), + new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), 1024, 1024, "voyage-clip-v2", @@ -1714,7 +1741,13 @@ public void testInfer_Embedding_DoesNotSetInputType_WhenNotPresentInTaskSettings MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "voyage-clip-v2"))); + MatcherAssert.assertThat(requestMap, is(Map.of( + "input", List.of("abc"), + "model", "voyage-clip-v2", + "input_type", "document", + "output_dtype", "float", + "output_dimension", 1024 + ))); } } @@ -1722,7 +1755,7 @@ public void test_Embedding_ChunkedInfer_BatchesCallsChunkingSettingsSet() throws var model = VoyageAIEmbeddingsModelTests.createModel( getUrl(webServer), "secret", - new VoyageAIEmbeddingsTaskSettings((InputType) null), + new VoyageAIEmbeddingsTaskSettings((InputType) null, null), createRandomChunkingSettings(), 1024, 1024, @@ -1736,7 +1769,7 @@ public void test_Embedding_ChunkedInfer_ChunkingSettingsNotSet() throws IOExcept var model = VoyageAIEmbeddingsModelTests.createModel( getUrl(webServer), "secret", - new VoyageAIEmbeddingsTaskSettings((InputType) null), + new VoyageAIEmbeddingsTaskSettings((InputType) null, null), null, 1024, 1024, @@ -1820,7 +1853,12 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(VoyageAIEmbeddingsModel mo MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("foo", "bar"), "model", "voyage-clip-v2"))); + MatcherAssert.assertThat(requestMap, is(Map.of( + "input", List.of("foo", "bar"), + "model", "voyage-clip-v2", + "output_dtype", "float", + "output_dimension", 1024 + ))); } } @@ -1847,24 +1885,6 @@ public void testGetConfiguration() throws Exception { "type": "str", "supported_task_types": ["text_embedding", "rerank"] }, - "dimensions": { - "description": "The number of dimensions the resulting embeddings should have. For more information refer to https://api.voyage.ai/redoc#tag/embeddings/operation/create_embedding_v1_embeddings_post.", - "label": "Dimensions", - "required": false, - "sensitive": false, - "updatable": false, - "type": "int", - "supported_task_types": ["text_embedding"] - }, - "model_id": { - "description": "The name of the model to use for the inference task.", - "label": "Model ID", - "required": true, - "sensitive": false, - "updatable": false, - "type": "str", - "supported_task_types": ["text_embedding", "rerank"] - }, "rate_limit.requests_per_minute": { "description": "Minimize the number of rate limit errors.", "label": "Rate Limit", diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsModelTests.java index c11748a122b5a..fdee56bf1a9c5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsModelTests.java @@ -13,11 +13,8 @@ import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings; -import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel; -import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings; -import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings; import org.hamcrest.MatcherAssert; import java.util.Map; @@ -42,50 +39,50 @@ public void testOverrideWith_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAre } public void testOverrideWith_SetsInputTypeToIngest_WhenTheFieldIsNullInModelTaskSettings_AndNullInRequestTaskSettings() { - var model = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings((InputType) null), null, null, "model"); + var model = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings((InputType) null, null), null, null, "model"); var overriddenModel = VoyageAIEmbeddingsModel.of(model, getTaskSettingsMap(null), InputType.INGEST); - var expectedModel = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST), null, null, "model"); + var expectedModel = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), null, null, "model"); MatcherAssert.assertThat(overriddenModel, is(expectedModel)); } public void testOverrideWith_SetsInputType_FromRequest_IfValid_OverridingStoredTaskSettings() { - var model = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST), null, null, "model"); + var model = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), null, null, "model"); var overriddenModel = VoyageAIEmbeddingsModel.of(model, getTaskSettingsMap(null), InputType.SEARCH); - var expectedModel = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH), null, null, "model"); + var expectedModel = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH, null), null, null, "model"); MatcherAssert.assertThat(overriddenModel, is(expectedModel)); } public void testOverrideWith_SetsInputType_FromRequest_IfValid_OverridingRequestTaskSettings() { - var model = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings((InputType) null), null, null, "model"); + var model = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings((InputType) null, null), null, null, "model"); var overriddenModel = VoyageAIEmbeddingsModel.of(model, getTaskSettingsMap(InputType.INGEST), InputType.SEARCH); - var expectedModel = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH), null, null, "model"); + var expectedModel = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH, null), null, null, "model"); MatcherAssert.assertThat(overriddenModel, is(expectedModel)); } public void testOverrideWith_OverridesInputType_WithRequestTaskSettingsSearch_WhenRequestInputTypeIsInvalid() { - var model = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST), null, null, "model"); + var model = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), null, null, "model"); var overriddenModel = VoyageAIEmbeddingsModel.of(model, getTaskSettingsMap(InputType.SEARCH), InputType.UNSPECIFIED); - var expectedModel = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH), null, null, "model"); + var expectedModel = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH, null), null, null, "model"); MatcherAssert.assertThat(overriddenModel, is(expectedModel)); } public void testOverrideWith_DoesNotSetInputType_FromRequest_IfInputTypeIsInvalid() { - var model = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings((InputType) null), null, null, "model"); + var model = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings((InputType) null, null), null, null, "model"); var overriddenModel = VoyageAIEmbeddingsModel.of(model, getTaskSettingsMap(null), InputType.UNSPECIFIED); - var expectedModel = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings((InputType) null), null, null, "model"); + var expectedModel = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings((InputType) null, null), null, null, "model"); MatcherAssert.assertThat(overriddenModel, is(expectedModel)); } public void testOverrideWith_DoesNotSetInputType_WhenRequestTaskSettingsIsNull_AndRequestInputTypeIsInvalid() { - var model = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST), null, null, "model"); + var model = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), null, null, "model"); var overriddenModel = VoyageAIEmbeddingsModel.of(model, getTaskSettingsMap(null), InputType.UNSPECIFIED); - var expectedModel = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST), null, null, "model"); + var expectedModel = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), null, null, "model"); MatcherAssert.assertThat(overriddenModel, is(expectedModel)); } @@ -117,6 +114,7 @@ public static VoyageAIEmbeddingsModel createModel( "service", new VoyageAIEmbeddingsServiceSettings( new VoyageAIServiceSettings(url, model, null), + VoyageAIEmbeddingType.FLOAT, SimilarityMeasure.DOT_PRODUCT, dimensions, tokenLimit @@ -140,6 +138,7 @@ public static VoyageAIEmbeddingsModel createModel( "service", new VoyageAIEmbeddingsServiceSettings( new VoyageAIServiceSettings(url, model, null), + VoyageAIEmbeddingType.FLOAT, SimilarityMeasure.DOT_PRODUCT, dimensions, tokenLimit @@ -162,7 +161,13 @@ public static VoyageAIEmbeddingsModel createModel( return new VoyageAIEmbeddingsModel( "id", "service", - new VoyageAIEmbeddingsServiceSettings(new VoyageAIServiceSettings(url, model, null), similarityMeasure, dimensions, tokenLimit), + new VoyageAIEmbeddingsServiceSettings( + new VoyageAIServiceSettings(url, model, null), + VoyageAIEmbeddingType.FLOAT, + similarityMeasure, + dimensions, + tokenLimit + ), taskSettings, null, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettingsTests.java index 62c799fae1557..538b9b77d2069 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettingsTests.java @@ -22,10 +22,9 @@ import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings; import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettingsTests; -import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings; -import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import org.hamcrest.MatcherAssert; import java.io.IOException; @@ -46,7 +45,7 @@ public static VoyageAIEmbeddingsServiceSettings createRandom() { var commonSettings = VoyageAIServiceSettingsTests.createRandom(); - return new VoyageAIEmbeddingsServiceSettings(commonSettings, similarityMeasure, dims, maxInputTokens); + return new VoyageAIEmbeddingsServiceSettings(commonSettings, VoyageAIEmbeddingType.FLOAT, similarityMeasure, dims, maxInputTokens); } public void testFromMap() { @@ -78,6 +77,7 @@ public void testFromMap() { is( new VoyageAIEmbeddingsServiceSettings( new VoyageAIServiceSettings(ServiceUtils.createUri(url), model, null), + VoyageAIEmbeddingType.FLOAT, SimilarityMeasure.DOT_PRODUCT, dims, maxInputTokens @@ -115,6 +115,7 @@ public void testFromMap_WithModelId() { is( new VoyageAIEmbeddingsServiceSettings( new VoyageAIServiceSettings(ServiceUtils.createUri(url), model, null), + VoyageAIEmbeddingType.FLOAT, SimilarityMeasure.DOT_PRODUCT, dims, maxInputTokens @@ -142,9 +143,11 @@ public void testFromMap_InvalidSimilarity_ThrowsError() { ); } + @SuppressWarnings("checkstyle:LineLength") public void testToXContent_WritesAllValues() throws IOException { var serviceSettings = new VoyageAIEmbeddingsServiceSettings( new VoyageAIServiceSettings("url", "model", new RateLimitSettings(3)), + VoyageAIEmbeddingType.FLOAT, SimilarityMeasure.COSINE, 5, 10 @@ -155,7 +158,7 @@ public void testToXContent_WritesAllValues() throws IOException { String xContentResult = Strings.toString(builder); assertThat(xContentResult, is(""" {"url":"url","model_id":"model",""" + """ - "rate_limit":{"requests_per_minute":3},"similarity":"cosine","dimensions":5,"max_input_tokens":10}""")); + "rate_limit":{"requests_per_minute":3},"similarity":"cosine","dimensions":5,"max_input_tokens":10,"embedding_type":"float"}""")); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsTaskSettingsTests.java index c7ed88cf51114..f3d85749e8e29 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsTaskSettingsTests.java @@ -13,7 +13,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InputType; import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceFields; import org.hamcrest.MatcherAssert; import java.io.IOException; @@ -24,16 +24,17 @@ import java.util.Locale; import java.util.Map; -import static org.elasticsearch.xpack.inference.InputTypeTests.randomWithoutUnspecified; +import static org.elasticsearch.xpack.inference.InputTypeTests.randomWithIngestAndSearch; import static org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings.VALID_REQUEST_VALUES; import static org.hamcrest.Matchers.is; public class VoyageAIEmbeddingsTaskSettingsTests extends AbstractWireSerializingTestCase { public static VoyageAIEmbeddingsTaskSettings createRandom() { - var inputType = randomBoolean() ? randomWithoutUnspecified() : null; + var inputType = randomBoolean() ? randomWithIngestAndSearch() : null; + var truncation = randomBoolean(); - return new VoyageAIEmbeddingsTaskSettings(inputType); + return new VoyageAIEmbeddingsTaskSettings(inputType, truncation); } public void testIsEmpty() { @@ -44,7 +45,7 @@ public void testIsEmpty() { public void testUpdatedTaskSettings_NotUpdated_UseInitialSettings() { var initialSettings = createRandom(); - var newSettings = new VoyageAIEmbeddingsTaskSettings((InputType) null); + var newSettings = new VoyageAIEmbeddingsTaskSettings((InputType) null, null); Map newSettingsMap = new HashMap<>(); VoyageAIEmbeddingsTaskSettings updatedSettings = (VoyageAIEmbeddingsTaskSettings) initialSettings.updatedTaskSettings( Collections.unmodifiableMap(newSettingsMap) @@ -54,7 +55,7 @@ public void testUpdatedTaskSettings_NotUpdated_UseInitialSettings() { public void testUpdatedTaskSettings_Updated_UseNewSettings() { var initialSettings = createRandom(); - var newSettings = new VoyageAIEmbeddingsTaskSettings(randomWithoutUnspecified()); + var newSettings = new VoyageAIEmbeddingsTaskSettings(randomWithIngestAndSearch(), randomBoolean()); Map newSettingsMap = new HashMap<>(); newSettingsMap.put(VoyageAIEmbeddingsTaskSettings.INPUT_TYPE, newSettings.getInputType().toString()); VoyageAIEmbeddingsTaskSettings updatedSettings = (VoyageAIEmbeddingsTaskSettings) initialSettings.updatedTaskSettings( @@ -66,27 +67,34 @@ public void testUpdatedTaskSettings_Updated_UseNewSettings() { public void testFromMap_CreatesEmptySettings_WhenAllFieldsAreNull() { MatcherAssert.assertThat( VoyageAIEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of())), - is(new VoyageAIEmbeddingsTaskSettings((InputType) null)) + is(new VoyageAIEmbeddingsTaskSettings((InputType) null, null)) ); } public void testFromMap_CreatesEmptySettings_WhenMapIsNull() { - MatcherAssert.assertThat(VoyageAIEmbeddingsTaskSettings.fromMap(null), is(new VoyageAIEmbeddingsTaskSettings((InputType) null))); + MatcherAssert.assertThat( + VoyageAIEmbeddingsTaskSettings.fromMap(null), + is(new VoyageAIEmbeddingsTaskSettings((InputType) null, null)) + ); } public void testFromMap_CreatesSettings_WhenAllFieldsOfSettingsArePresent() { MatcherAssert.assertThat( VoyageAIEmbeddingsTaskSettings.fromMap( - new HashMap<>(Map.of(VoyageAIEmbeddingsTaskSettings.INPUT_TYPE, InputType.INGEST.toString())) + new HashMap<>( + Map.of(VoyageAIEmbeddingsTaskSettings.INPUT_TYPE, InputType.INGEST.toString(), VoyageAIServiceFields.TRUNCATION, false) + ) ), - is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST)) + is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, false)) ); } public void testFromMap_ReturnsFailure_WhenInputTypeIsInvalid() { var exception = expectThrows( ValidationException.class, - () -> VoyageAIEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(VoyageAIEmbeddingsTaskSettings.INPUT_TYPE, "abc"))) + () -> VoyageAIEmbeddingsTaskSettings.fromMap( + new HashMap<>(Map.of(VoyageAIEmbeddingsTaskSettings.INPUT_TYPE, "abc", VoyageAIServiceFields.TRUNCATION, false)) + ) ); MatcherAssert.assertThat( @@ -100,6 +108,22 @@ public void testFromMap_ReturnsFailure_WhenInputTypeIsInvalid() { ); } + public void testFromMap_ReturnsFailure_WhenTruncationIsInvalid() { + var exception = expectThrows( + ValidationException.class, + () -> VoyageAIEmbeddingsTaskSettings.fromMap( + new HashMap<>( + Map.of(VoyageAIEmbeddingsTaskSettings.INPUT_TYPE, InputType.INGEST.toString(), VoyageAIServiceFields.TRUNCATION, "abc") + ) + ) + ); + + MatcherAssert.assertThat( + exception.getMessage(), + is("Validation Failed: 1: field [truncation] is not of the expected type. The value [abc] cannot be converted to a [Boolean];") + ); + } + public void testFromMap_ReturnsFailure_WhenInputTypeIsUnspecified() { var exception = expectThrows( ValidationException.class, @@ -127,12 +151,12 @@ private static > String getValidValuesSortedAndCombined(EnumSe } public void testXContent_ThrowsAssertionFailure_WhenInputTypeIsUnspecified() { - var thrownException = expectThrows(AssertionError.class, () -> new VoyageAIEmbeddingsTaskSettings(InputType.UNSPECIFIED)); + var thrownException = expectThrows(AssertionError.class, () -> new VoyageAIEmbeddingsTaskSettings(InputType.UNSPECIFIED, null)); MatcherAssert.assertThat(thrownException.getMessage(), is("received invalid input type value [unspecified]")); } public void testOf_KeepsOriginalValuesWhenRequestSettingsAreNull_AndRequestInputTypeIsInvalid() { - var taskSettings = new VoyageAIEmbeddingsTaskSettings(InputType.INGEST); + var taskSettings = new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, false); var overriddenTaskSettings = VoyageAIEmbeddingsTaskSettings.of( taskSettings, VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, @@ -142,25 +166,25 @@ public void testOf_KeepsOriginalValuesWhenRequestSettingsAreNull_AndRequestInput } public void testOf_UsesRequestTaskSettings() { - var taskSettings = new VoyageAIEmbeddingsTaskSettings((InputType) null); + var taskSettings = new VoyageAIEmbeddingsTaskSettings((InputType) null, null); var overriddenTaskSettings = VoyageAIEmbeddingsTaskSettings.of( taskSettings, - new VoyageAIEmbeddingsTaskSettings(InputType.INGEST), + new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, true), InputType.UNSPECIFIED ); - MatcherAssert.assertThat(overriddenTaskSettings, is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST))); + MatcherAssert.assertThat(overriddenTaskSettings, is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, true))); } public void testOf_UsesRequestTaskSettings_AndRequestInputType() { - var taskSettings = new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH); + var taskSettings = new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH, true); var overriddenTaskSettings = VoyageAIEmbeddingsTaskSettings.of( taskSettings, - new VoyageAIEmbeddingsTaskSettings((InputType) null), + new VoyageAIEmbeddingsTaskSettings((InputType) null, null), InputType.INGEST ); - MatcherAssert.assertThat(overriddenTaskSettings, is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST))); + MatcherAssert.assertThat(overriddenTaskSettings, is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, true))); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModelTests.java index a05d95cdbba42..0488e61a43ba3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModelTests.java @@ -10,11 +10,8 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings; -import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel; -import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankServiceSettings; -import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankTaskSettings; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings; public class VoyageAIRerankModelTests { @@ -23,7 +20,7 @@ public static VoyageAIRerankModel createModel(String apiKey, String modelId, @Nu "id", "service", new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(ESTestCase.randomAlphaOfLength(10), modelId, null)), - new VoyageAIRerankTaskSettings(topN, null), + new VoyageAIRerankTaskSettings(topN, null, null), new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); } @@ -33,27 +30,33 @@ public static VoyageAIRerankModel createModel(String modelId, @Nullable Integer "id", "service", new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(ESTestCase.randomAlphaOfLength(10), modelId, null)), - new VoyageAIRerankTaskSettings(topN, null), + new VoyageAIRerankTaskSettings(topN, null, null), new DefaultSecretSettings(ESTestCase.randomSecureStringOfLength(8)) ); } - public static VoyageAIRerankModel createModel(String modelId, @Nullable Integer topN, Boolean returnDocuments) { + public static VoyageAIRerankModel createModel(String modelId, @Nullable Integer topN, Boolean returnDocuments, Boolean truncation) { return new VoyageAIRerankModel( "id", "service", new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(ESTestCase.randomAlphaOfLength(10), modelId, null)), - new VoyageAIRerankTaskSettings(topN, returnDocuments), + new VoyageAIRerankTaskSettings(topN, returnDocuments, truncation), new DefaultSecretSettings(ESTestCase.randomSecureStringOfLength(8)) ); } - public static VoyageAIRerankModel createModel(String url, String modelId, @Nullable Integer topN, Boolean returnDocuments) { + public static VoyageAIRerankModel createModel( + String url, + String modelId, + @Nullable Integer topN, + Boolean returnDocuments, + Boolean truncation + ) { return new VoyageAIRerankModel( "id", "service", new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(url, modelId, null)), - new VoyageAIRerankTaskSettings(topN, returnDocuments), + new VoyageAIRerankTaskSettings(topN, returnDocuments, truncation), new DefaultSecretSettings(ESTestCase.randomSecureStringOfLength(8)) ); } @@ -63,13 +66,14 @@ public static VoyageAIRerankModel createModel( String apiKey, String modelId, @Nullable Integer topN, - Boolean returnDocuments + Boolean returnDocuments, + Boolean truncation ) { return new VoyageAIRerankModel( "id", "service", new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(url, modelId, null)), - new VoyageAIRerankTaskSettings(topN, returnDocuments), + new VoyageAIRerankTaskSettings(topN, returnDocuments, truncation), new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankServiceSettingsTests.java index 64f0f527c8a77..429be4a2c31d7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankServiceSettingsTests.java @@ -15,10 +15,9 @@ import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings; import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettingsTests; -import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankServiceSettings; -import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; import java.io.IOException; import java.util.HashMap; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettingsTests.java index b471bdeae898f..85cc792c03244 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettingsTests.java @@ -11,7 +11,6 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.core.Nullable; import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankTaskSettings; import java.io.IOException; import java.util.HashMap; @@ -24,28 +23,34 @@ public class VoyageAIRerankTaskSettingsTests extends AbstractWireSerializingTest public static VoyageAIRerankTaskSettings createRandom() { var returnDocuments = randomBoolean() ? randomBoolean() : null; var topNDocsOnly = randomBoolean() ? randomIntBetween(1, 10) : null; + var truncation = randomBoolean() ? randomBoolean() : null; - return new VoyageAIRerankTaskSettings(topNDocsOnly, returnDocuments); + return new VoyageAIRerankTaskSettings(topNDocsOnly, returnDocuments, truncation); } public void testFromMap_WithValidValues_ReturnsSettings() { - Map taskMap = Map.of(VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, true, VoyageAIRerankTaskSettings.TOP_N_DOCS_ONLY, 5); + Map taskMap = Map.of( + VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, + true, + VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY, + 5 + ); var settings = VoyageAIRerankTaskSettings.fromMap(new HashMap<>(taskMap)); assertTrue(settings.getReturnDocuments()); - assertEquals(5, settings.getTopNDocumentsOnly().intValue()); + assertEquals(5, settings.getTopKDocumentsOnly().intValue()); } public void testFromMap_WithNullValues_ReturnsSettingsWithNulls() { var settings = VoyageAIRerankTaskSettings.fromMap(Map.of()); assertNull(settings.getReturnDocuments()); - assertNull(settings.getTopNDocumentsOnly()); + assertNull(settings.getTopKDocumentsOnly()); } public void testFromMap_WithInvalidReturnDocuments_ThrowsValidationException() { Map taskMap = Map.of( VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, "invalid", - VoyageAIRerankTaskSettings.TOP_N_DOCS_ONLY, + VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY, 5 ); var thrownException = expectThrows(ValidationException.class, () -> VoyageAIRerankTaskSettings.fromMap(new HashMap<>(taskMap))); @@ -56,46 +61,49 @@ public void testFromMap_WithInvalidTopNDocsOnly_ThrowsValidationException() { Map taskMap = Map.of( VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, true, - VoyageAIRerankTaskSettings.TOP_N_DOCS_ONLY, + VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY, "invalid" ); var thrownException = expectThrows(ValidationException.class, () -> VoyageAIRerankTaskSettings.fromMap(new HashMap<>(taskMap))); - assertThat(thrownException.getMessage(), containsString("field [top_n] is not of the expected type")); + assertThat(thrownException.getMessage(), containsString("Validation Failed: 1: field [top_k] is not of the expected type. The value [invalid] cannot be converted to a [Integer];")); } public void testUpdatedTaskSettings_WithEmptyMap_ReturnsSameSettings() { - var initialSettings = new VoyageAIRerankTaskSettings(5, true); + var initialSettings = new VoyageAIRerankTaskSettings(5, true, true); VoyageAIRerankTaskSettings updatedSettings = (VoyageAIRerankTaskSettings) initialSettings.updatedTaskSettings(Map.of()); assertEquals(initialSettings, updatedSettings); } public void testUpdatedTaskSettings_WithNewReturnDocuments_ReturnsUpdatedSettings() { - var initialSettings = new VoyageAIRerankTaskSettings(5, true); + var initialSettings = new VoyageAIRerankTaskSettings(5, true, true); Map newSettings = Map.of(VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, false); VoyageAIRerankTaskSettings updatedSettings = (VoyageAIRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings); assertFalse(updatedSettings.getReturnDocuments()); - assertEquals(initialSettings.getTopNDocumentsOnly(), updatedSettings.getTopNDocumentsOnly()); + assertTrue(updatedSettings.getTruncation()); + assertEquals(initialSettings.getTopKDocumentsOnly(), updatedSettings.getTopKDocumentsOnly()); } public void testUpdatedTaskSettings_WithNewTopNDocsOnly_ReturnsUpdatedSettings() { - var initialSettings = new VoyageAIRerankTaskSettings(5, true); - Map newSettings = Map.of(VoyageAIRerankTaskSettings.TOP_N_DOCS_ONLY, 7); + var initialSettings = new VoyageAIRerankTaskSettings(5, true, true); + Map newSettings = Map.of(VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY, 7); VoyageAIRerankTaskSettings updatedSettings = (VoyageAIRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings); - assertEquals(7, updatedSettings.getTopNDocumentsOnly().intValue()); + assertTrue(updatedSettings.getTruncation()); + assertEquals(7, updatedSettings.getTopKDocumentsOnly().intValue()); assertEquals(initialSettings.getReturnDocuments(), updatedSettings.getReturnDocuments()); } public void testUpdatedTaskSettings_WithMultipleNewValues_ReturnsUpdatedSettings() { - var initialSettings = new VoyageAIRerankTaskSettings(5, true); + var initialSettings = new VoyageAIRerankTaskSettings(5, true, true); Map newSettings = Map.of( VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, false, - VoyageAIRerankTaskSettings.TOP_N_DOCS_ONLY, + VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY, 7 ); VoyageAIRerankTaskSettings updatedSettings = (VoyageAIRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings); + assertTrue(updatedSettings.getTruncation()); assertFalse(updatedSettings.getReturnDocuments()); - assertEquals(7, updatedSettings.getTopNDocumentsOnly().intValue()); + assertEquals(7, updatedSettings.getTopKDocumentsOnly().intValue()); } @Override @@ -121,7 +129,7 @@ public static Map getTaskSettingsMap(@Nullable Integer topNDocum var map = new HashMap(); if (topNDocumentsOnly != null) { - map.put(VoyageAIRerankTaskSettings.TOP_N_DOCS_ONLY, topNDocumentsOnly.toString()); + map.put(VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY, topNDocumentsOnly.toString()); } if (returnDocuments != null) { From d41538adb407c0f5d00a72385ee997d05a2c2903 Mon Sep 17 00:00:00 2001 From: fzowl Date: Sat, 8 Feb 2025 23:07:17 +0100 Subject: [PATCH 10/20] More tests --- .../VoyageAIEmbeddingsServiceSettings.java | 7 +- .../voyageai/VoyageAIActionCreatorTests.java | 276 ++++++++++ .../VoyageAIEmbeddingsActionTests.java | 364 +++++++++++++ .../VoyageAIEmbeddingsRequestEntityTests.java | 66 +++ .../VoyageAIEmbeddingsRequestTests.java | 130 +++++ .../voyageai/VoyageAIRequestTests.java | 38 ++ .../VoyageAIRerankRequestEntityTests.java | 145 +++++ .../voyageai/VoyageAIRerankRequestTests.java | 111 ++++ .../request/voyageai/VoyageAIUtilsTests.java | 24 + ...VoyageAIEmbeddingsResponseEntityTests.java | 506 ++++++++++++++++++ .../VoyageAIErrorResponseEntityTests.java | 52 ++ .../VoyageAIRerankResponseEntityTests.java | 175 ++++++ .../VoyageAIResponseHandlerTests.java | 139 +++++ 13 files changed, 2031 insertions(+), 2 deletions(-) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreatorTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIEmbeddingsActionTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntityTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRequestTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntityTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIUtilsTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntityTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIErrorResponseEntityTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIRerankResponseEntityTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/voyageai/VoyageAIResponseHandlerTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettings.java index 909fd4bf99514..579c140a7da24 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettings.java @@ -38,6 +38,9 @@ public class VoyageAIEmbeddingsServiceSettings extends FilteredXContentObject implements ServiceSettings { public static final String NAME = "voyageai_embeddings_service_settings"; + public static final VoyageAIEmbeddingsServiceSettings EMPTY_SETTINGS = new VoyageAIEmbeddingsServiceSettings( + null, null, null, null, null + ); static final String EMBEDDING_TYPE = "embedding_type"; @@ -118,7 +121,7 @@ static VoyageAIEmbeddingType fromVoyageAIOrDenseVectorEnumValues(String enumStri public VoyageAIEmbeddingsServiceSettings( VoyageAIServiceSettings commonSettings, - VoyageAIEmbeddingType embeddingType, + @Nullable VoyageAIEmbeddingType embeddingType, @Nullable SimilarityMeasure similarity, @Nullable Integer dimensions, @Nullable Integer maxInputTokens @@ -127,7 +130,7 @@ public VoyageAIEmbeddingsServiceSettings( this.similarity = similarity; this.dimensions = dimensions; this.maxInputTokens = maxInputTokens; - this.embeddingType = Objects.requireNonNull(embeddingType); + this.embeddingType = embeddingType; } public VoyageAIEmbeddingsServiceSettings(StreamInput in) throws IOException { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreatorTests.java new file mode 100644 index 0000000000000..1061c61ad9dc2 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreatorTests.java @@ -0,0 +1,276 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.voyageai; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.external.action.cohere.CohereActionCreator; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation; +import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModelTests; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests; +import org.hamcrest.MatcherAssert; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; +import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.equalTo; +import static org.mockito.Mockito.mock; + +public class VoyageAIActionCreatorTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testCreate_CohereEmbeddingsModel() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "id": "de37399c-5df6-47cb-bc57-e3c5680c977b", + "texts": [ + "hello" + ], + "embeddings": { + "float": [ + [ + 0.123, + -0.123 + ] + ] + }, + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "input_tokens": 1 + } + }, + "response_type": "embeddings_by_type" + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = CohereEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + new CohereEmbeddingsTaskSettings(InputType.INGEST, CohereTruncation.START), + 1024, + 1024, + "model", + CohereEmbeddingType.FLOAT + ); + var actionCreator = new CohereActionCreator(sender, createWithEmptySettings(threadPool)); + var overriddenTaskSettings = CohereEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH, CohereTruncation.END); + var action = actionCreator.create(model, overriddenTaskSettings, InputType.UNSPECIFIED); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + MatcherAssert.assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })))); + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat( + requestMap, + is( + Map.of( + "texts", + List.of("abc"), + "model", + "model", + "input_type", + "search_query", + "embedding_types", + List.of("float"), + "truncate", + "end" + ) + ) + ); + } + } + + public void testCreate_CohereCompletionModel_WithModelSpecified() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "response_id": "some id", + "text": "result", + "generation_id": "some id", + "chat_history": [ + { + "role": "USER", + "message": "input" + }, + { + "role": "CHATBOT", + "message": "result" + } + ], + "finish_reason": "COMPLETE", + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "input_tokens": 4, + "output_tokens": 191 + }, + "tokens": { + "input_tokens": 70, + "output_tokens": 191 + } + } + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = CohereCompletionModelTests.createModel(getUrl(webServer), "secret", "model"); + var actionCreator = new CohereActionCreator(sender, createWithEmptySettings(threadPool)); + var action = actionCreator.create(model, Map.of()); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("result")))); + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), is(XContentType.JSON.mediaType())); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), is("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap, is(Map.of("message", "abc", "model", "model"))); + } + } + + public void testCreate_CohereCompletionModel_WithoutModelSpecified() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "response_id": "some id", + "text": "result", + "generation_id": "some id", + "chat_history": [ + { + "role": "USER", + "message": "input" + }, + { + "role": "CHATBOT", + "message": "result" + } + ], + "finish_reason": "COMPLETE", + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "input_tokens": 4, + "output_tokens": 191 + }, + "tokens": { + "input_tokens": 70, + "output_tokens": 191 + } + } + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = CohereCompletionModelTests.createModel(getUrl(webServer), "secret", null); + var actionCreator = new CohereActionCreator(sender, createWithEmptySettings(threadPool)); + var action = actionCreator.create(model, Map.of()); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("result")))); + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), is(XContentType.JSON.mediaType())); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), is("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap, is(Map.of("message", "abc"))); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIEmbeddingsActionTests.java new file mode 100644 index 0000000000000..c560e3755ad3c --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIEmbeddingsActionTests.java @@ -0,0 +1,364 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.voyageai; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.sender.CohereEmbeddingsRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.request.cohere.CohereUtils; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; +import org.hamcrest.MatcherAssert; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationByte; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; + +public class VoyageAIEmbeddingsActionTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testExecute_ReturnsSuccessfulResponse() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = HttpRequestSenderTests.createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "id": "de37399c-5df6-47cb-bc57-e3c5680c977b", + "texts": [ + "hello" + ], + "embeddings": { + "float": [ + [ + 0.123, + -0.123 + ] + ] + }, + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "input_tokens": 1 + } + }, + "response_type": "embeddings_by_type" + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var action = createAction( + getUrl(webServer), + "secret", + new CohereEmbeddingsTaskSettings(InputType.INGEST, CohereTruncation.START), + "model", + CohereEmbeddingType.FLOAT, + sender + ); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + MatcherAssert.assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })))); + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(CohereUtils.REQUEST_SOURCE_HEADER), + equalTo(CohereUtils.ELASTIC_REQUEST_SOURCE) + ); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat( + requestMap, + is( + Map.of( + "texts", + List.of("abc"), + "model", + "model", + "input_type", + "search_document", + "embedding_types", + List.of("float"), + "truncate", + "start" + ) + ) + ); + } + } + + public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = HttpRequestSenderTests.createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "id": "de37399c-5df6-47cb-bc57-e3c5680c977b", + "texts": [ + "hello" + ], + "embeddings": { + "int8": [ + [ + 0, + -1 + ] + ] + }, + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "input_tokens": 1 + } + }, + "response_type": "embeddings_by_type" + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var action = createAction( + getUrl(webServer), + "secret", + new CohereEmbeddingsTaskSettings(InputType.INGEST, CohereTruncation.START), + "model", + CohereEmbeddingType.INT8, + sender + ); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + assertEquals(buildExpectationByte(List.of(new byte[] { 0, -1 })), result.asMap()); + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(CohereUtils.REQUEST_SOURCE_HEADER), + equalTo(CohereUtils.ELASTIC_REQUEST_SOURCE) + ); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat( + requestMap, + is( + Map.of( + "texts", + List.of("abc"), + "model", + "model", + "input_type", + "search_document", + "embedding_types", + List.of("int8"), + "truncate", + "start" + ) + ) + ); + } + } + + public void testExecute_ThrowsURISyntaxException_ForInvalidUrl() throws IOException { + try (var sender = mock(Sender.class)) { + var thrownException = expectThrows( + IllegalArgumentException.class, + () -> createAction("^^", "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender) + ); + MatcherAssert.assertThat(thrownException.getMessage(), containsString("unable to parse url [^^]")); + } + } + + public void testExecute_ThrowsElasticsearchException() { + var sender = mock(Sender.class); + doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any(), any(), any()); + + var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + MatcherAssert.assertThat(thrownException.getMessage(), is("failed")); + } + + public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() { + var sender = mock(Sender.class); + + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new IllegalStateException("failed")); + + return Void.TYPE; + }).when(sender).send(any(), any(), any(), any()); + + var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + MatcherAssert.assertThat( + thrownException.getMessage(), + is(format("Failed to send Cohere embeddings request to [%s]", getUrl(webServer))) + ); + } + + public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled_WhenUrlIsNull() { + var sender = mock(Sender.class); + + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new IllegalStateException("failed")); + + return Void.TYPE; + }).when(sender).send(any(), any(), any(), any()); + + var action = createAction(null, "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + MatcherAssert.assertThat(thrownException.getMessage(), is("Failed to send Cohere embeddings request")); + } + + public void testExecute_ThrowsException() { + var sender = mock(Sender.class); + doThrow(new IllegalArgumentException("failed")).when(sender).send(any(), any(), any(), any()); + + var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + MatcherAssert.assertThat( + thrownException.getMessage(), + is(format("Failed to send Cohere embeddings request to [%s]", getUrl(webServer))) + ); + } + + public void testExecute_ThrowsExceptionWithNullUrl() { + var sender = mock(Sender.class); + doThrow(new IllegalArgumentException("failed")).when(sender).send(any(), any(), any(), any()); + + var action = createAction(null, "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + MatcherAssert.assertThat(thrownException.getMessage(), is("Failed to send Cohere embeddings request")); + } + + private ExecutableAction createAction( + String url, + String apiKey, + CohereEmbeddingsTaskSettings taskSettings, + @Nullable String modelName, + @Nullable CohereEmbeddingType embeddingType, + Sender sender + ) { + var model = CohereEmbeddingsModelTests.createModel(url, apiKey, taskSettings, 1024, 1024, modelName, embeddingType); + var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage( + model.getServiceSettings().getCommonSettings().uri(), + "Cohere embeddings" + ); + var requestCreator = CohereEmbeddingsRequestManager.of(model, threadPool); + return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntityTests.java new file mode 100644 index 0000000000000..9685913bd04e6 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntityTests.java @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.voyageai; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIEmbeddingsRequestEntity; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.List; + +import static org.hamcrest.CoreMatchers.is; + +public class VoyageAIEmbeddingsRequestEntityTests extends ESTestCase { + public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { + var entity = new VoyageAIEmbeddingsRequestEntity( + List.of("abc"), + VoyageAIEmbeddingsServiceSettings.EMPTY_SETTINGS, + new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), + "model" + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, is(""" + {"input":["abc"],"model":"model","input_type":"document"}""")); + } + + public void testXContent_WritesNoOptionalFields_WhenTheyAreNotDefined() throws IOException { + var entity = new VoyageAIEmbeddingsRequestEntity( + List.of("abc"), + VoyageAIEmbeddingsServiceSettings.EMPTY_SETTINGS, + VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + "model" + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, is(""" + {"input":["abc"],"model":"model"}""")); + } + + public void testConvertToString_ThrowsAssertionFailure_WhenInputTypeIsUnspecified() { + var thrownException = expectThrows( + AssertionError.class, + () -> VoyageAIEmbeddingsRequestEntity.convertToString(InputType.UNSPECIFIED) + ); + MatcherAssert.assertThat(thrownException.getMessage(), is("received invalid input type value [unspecified]")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestTests.java new file mode 100644 index 0000000000000..df3901325ff04 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestTests.java @@ -0,0 +1,130 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.voyageai; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIEmbeddingsRequest; +import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIUtils; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class VoyageAIEmbeddingsRequestTests extends ESTestCase { + public void testCreateRequest_UrlDefined() throws IOException { + var request = createRequest( + List.of("abc"), + VoyageAIEmbeddingsModelTests.createModel("url", "secret", VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, "model") + ); + + var httpRequest = request.createHttpRequest(); + MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + MatcherAssert.assertThat( + httpPost.getLastHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER).getValue(), + is(VoyageAIUtils.ELASTIC_REQUEST_SOURCE) + ); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + MatcherAssert.assertThat(requestMap, is(Map.of( + "input", List.of("abc"), + "model", "model", + "output_dtype", "float" + ))); + } + + public void testCreateRequest_AllOptionsDefined() throws IOException { + var request = createRequest( + List.of("abc"), + VoyageAIEmbeddingsModelTests.createModel( + "url", + "secret", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), + null, + null, + "model" + ) + ); + + var httpRequest = request.createHttpRequest(); + MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + MatcherAssert.assertThat( + httpPost.getLastHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER).getValue(), + is(VoyageAIUtils.ELASTIC_REQUEST_SOURCE) + ); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + MatcherAssert.assertThat(requestMap, is(Map.of( + "input", List.of("abc"), + "model", "model", + "input_type", "document", + "output_dtype", "float" + ))); + } + + public void testCreateRequest_InputTypeSearch() throws IOException { + var request = createRequest( + List.of("abc"), + VoyageAIEmbeddingsModelTests.createModel( + "url", + "secret", + new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH, null), + null, + null, + "model" + ) + ); + + var httpRequest = request.createHttpRequest(); + MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + MatcherAssert.assertThat( + httpPost.getLastHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER).getValue(), + is(VoyageAIUtils.ELASTIC_REQUEST_SOURCE) + ); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + MatcherAssert.assertThat(requestMap, is(Map.of( + "input", List.of("abc"), + "model", "model", + "input_type", "query", + "output_dtype", "float" + ))); + } + + public static VoyageAIEmbeddingsRequest createRequest(List input, VoyageAIEmbeddingsModel model) { + return new VoyageAIEmbeddingsRequest(input, model); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRequestTests.java new file mode 100644 index 0000000000000..1ca734ac6e24d --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRequestTests.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.voyageai; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIAccount; +import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIRequest; +import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIUtils; + +import java.net.URI; + +import static org.hamcrest.Matchers.is; + +public class VoyageAIRequestTests extends ESTestCase { + + public void testDecorateWithAuthHeader() { + var request = new HttpPost("http://www.abc.com"); + + VoyageAIRequest.decorateWithAuthHeader( + request, + new VoyageAIAccount(URI.create("http://www.abc.com"), new SecureString(new char[] { 'a', 'b', 'c' })) + ); + + assertThat(request.getFirstHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(request.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer abc")); + assertThat(request.getFirstHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER).getValue(), is(VoyageAIUtils.ELASTIC_REQUEST_SOURCE)); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntityTests.java new file mode 100644 index 0000000000000..801717e8d37bd --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntityTests.java @@ -0,0 +1,145 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.voyageai; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIRerankRequestEntity; +import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankTaskSettings; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString; + +public class VoyageAIRerankRequestEntityTests extends ESTestCase { + public void testXContent_SingleRequest_WritesModelAndTopKIfDefined() throws IOException { + var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, null, null), "model"); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "documents": [ + "abc" + ], + "top_k": 8 + } + """)); + } + + public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_ReturnDocumentsTrue() throws IOException { + var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, true, null), "model"); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "documents": [ + "abc" + ], + "return_documents": true, + "top_k": 8 + } + """)); + } + + public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_ReturnDocumentsFalse() throws IOException { + var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, false, null), "model"); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "documents": [ + "abc" + ], + "return_documents": false, + "top_k": 8 + } + """)); + } + + public void testXContent_SingleRequest_DoesNotWriteTopKIfNull() throws IOException { + var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), null, "model"); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "documents": [ + "abc" + ] + } + """)); + } + + public void testXContent_MultipleRequests_WritesModelAndTopKIfDefined() throws IOException { + var entity = new VoyageAIRerankRequestEntity( + "query", + List.of("abc", "def"), + new VoyageAIRerankTaskSettings(8, null, null), + "model" + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "documents": [ + "abc", + "def" + ], + "top_k": 8 + } + """)); + } + + public void testXContent_MultipleRequests_DoesNotWriteTopKIfNull() throws IOException { + var entity = new VoyageAIRerankRequestEntity("query", List.of("abc", "def"), null, "model"); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "documents": [ + "abc", + "def" + ] + } + """)); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestTests.java new file mode 100644 index 0000000000000..416b11d39350d --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestTests.java @@ -0,0 +1,111 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.voyageai; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIRerankRequest; +import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModelTests; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.sameInstance; + +public class VoyageAIRerankRequestTests extends ESTestCase { + + private static final String API_KEY = "foo"; + + public void testCreateRequest_WithoutModelSet_And_WithoutTopNSet() throws IOException { + var input = "input"; + var query = "query"; + var modelId = "model"; + + var request = createRequest(query, input, modelId, null); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer " + API_KEY)); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + + assertThat(requestMap, aMapWithSize(3)); + assertThat(requestMap.get("documents"), is(List.of(input))); + assertThat(requestMap.get("query"), is(query)); + assertThat(requestMap.get("model"), is(modelId)); + } + + public void testCreateRequest_WithTopNSet() throws IOException { + var input = "input"; + var query = "query"; + var topK = 1; + var modelId = "model"; + + var request = createRequest(query, input, modelId, topK); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer " + API_KEY)); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + + assertThat(requestMap, aMapWithSize(4)); + assertThat(requestMap.get("documents"), is(List.of(input))); + assertThat(requestMap.get("query"), is(query)); + assertThat(requestMap.get("top_k"), is(topK)); + assertThat(requestMap.get("model"), is(modelId)); + } + + public void testCreateRequest_WithModelSet() throws IOException { + var input = "input"; + var query = "query"; + var modelId = "model"; + + var request = createRequest(query, input, modelId, null); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer " + API_KEY)); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + + assertThat(requestMap, aMapWithSize(3)); + assertThat(requestMap.get("documents"), is(List.of(input))); + assertThat(requestMap.get("query"), is(query)); + assertThat(requestMap.get("model"), is(modelId)); + } + + public void testTruncate_DoesNotTruncate() { + var request = createRequest("query", "input", "null", null); + var truncatedRequest = request.truncate(); + + assertThat(truncatedRequest, sameInstance(request)); + } + + private static VoyageAIRerankRequest createRequest(String query, String input, @Nullable String modelId, @Nullable Integer topN) { + var rerankModel = VoyageAIRerankModelTests.createModel(API_KEY, modelId, topN); + return new VoyageAIRerankRequest(query, List.of(input), rerankModel); + + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIUtilsTests.java new file mode 100644 index 0000000000000..186a2b0410111 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIUtilsTests.java @@ -0,0 +1,24 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.voyageai; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIUtils; + +import static org.hamcrest.Matchers.is; + +public class VoyageAIUtilsTests extends ESTestCase { + + public void testCreateRequestSourceHeader() { + var requestSourceHeader = VoyageAIUtils.createRequestSourceHeader(); + + assertThat(requestSourceHeader.getName(), is("Request-Source")); + assertThat(requestSourceHeader.getValue(), is("unspecified:elasticsearch")); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntityTests.java new file mode 100644 index 0000000000000..155f3be2dd904 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntityTests.java @@ -0,0 +1,506 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.voyageai; + +import org.apache.http.HttpResponse; +import org.elasticsearch.common.ParsingException; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIEmbeddingsRequest; +import org.elasticsearch.xpack.inference.external.response.voyageai.VoyageAIEmbeddingsResponseEntity; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.List; + +import static org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModelTests.createModel; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class VoyageAIEmbeddingsResponseEntityTests extends ESTestCase { + public void testFromResponse_CreatesResultsForASingleItem() throws IOException { + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.014539449, + -0.015288644 + ] + } + ], + "model": "voyage-embeddings-v3", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + + VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest( + List.of("abc", "def"), + createModel( + "url", + "api_key", + null, + "voyage-3-large" + ) + ); + + InferenceServiceResults parsedResults = VoyageAIEmbeddingsResponseEntity.fromResponse( + request, + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat( + ((InferenceTextEmbeddingFloatResults)parsedResults).embeddings(), + is(List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.014539449F, -0.015288644F }))) + ); + } + + public void testFromResponse_CreatesResultsForMultipleItems() throws IOException { + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.014539449, + -0.015288644 + ] + }, + { + "object": "embedding", + "index": 1, + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "model": "voyage-embeddings-v3", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + + VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest( + List.of("abc", "def"), + createModel( + "url", + "api_key", + null, + "voyage-3-large" + ) + ); + + InferenceServiceResults parsedResults = VoyageAIEmbeddingsResponseEntity.fromResponse( + request, + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat( + ((InferenceTextEmbeddingFloatResults)parsedResults).embeddings(), + is( + List.of( + new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.014539449F, -0.015288644F }), + new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.0123F, -0.0123F }) + ) + ) + ); + } + + public void testFromResponse_FailsWhenDataFieldIsNotPresent() { + String responseJson = """ + { + "object": "list", + "not_data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.014539449, + -0.015288644 + ] + } + ], + "model": "voyage-embeddings-v3", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + + VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest( + List.of("abc", "def"), + createModel( + "url", + "api_key", + null, + "voyage-3-large" + ) + ); + + var thrownException = expectThrows( + IllegalStateException.class, + () -> VoyageAIEmbeddingsResponseEntity.fromResponse( + request, + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + + assertThat(thrownException.getMessage(), is("Failed to find required field [data] in VoyageAI embeddings response")); + } + + public void testFromResponse_FailsWhenDataFieldNotAnArray() { + String responseJson = """ + { + "object": "list", + "data": { + "test": { + "object": "embedding", + "index": 0, + "embedding": [ + 0.014539449, + -0.015288644 + ] + } + }, + "model": "voyage-embeddings-v3", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + + VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest( + List.of("abc", "def"), + createModel( + "url", + "api_key", + null, + "voyage-3-large" + ) + ); + + var thrownException = expectThrows( + ParsingException.class, + () -> VoyageAIEmbeddingsResponseEntity.fromResponse( + request, + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + + assertThat( + thrownException.getMessage(), + is("Failed to parse object: expecting token of type [START_ARRAY] but found [START_OBJECT]") + ); + } + + public void testFromResponse_FailsWhenEmbeddingsDoesNotExist() { + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embeddingzzz": [ + 0.014539449, + -0.015288644 + ] + } + ], + "model": "voyage-embeddings-v3", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + + VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest( + List.of("abc", "def"), + createModel( + "url", + "api_key", + null, + "voyage-3-large" + ) + ); + + var thrownException = expectThrows( + IllegalStateException.class, + () -> VoyageAIEmbeddingsResponseEntity.fromResponse( + request, + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + + assertThat(thrownException.getMessage(), is("Failed to find required field [embedding] in VoyageAI embeddings response")); + } + + public void testFromResponse_FailsWhenEmbeddingValueIsAString() { + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + "abc" + ] + } + ], + "model": "voyage-embeddings-v3", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + + VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest( + List.of("abc", "def"), + createModel( + "url", + "api_key", + null, + "voyage-3-large" + ) + ); + + var thrownException = expectThrows( + ParsingException.class, + () -> VoyageAIEmbeddingsResponseEntity.fromResponse( + request, + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + + assertThat( + thrownException.getMessage(), + is("Failed to parse object: expecting token of type [VALUE_NUMBER] but found [VALUE_STRING]") + ); + } + + public void testFromResponse_SucceedsWhenEmbeddingValueIsInt() throws IOException { + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 1 + ] + } + ], + "model": "voyage-embeddings-v3", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + + VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest( + List.of("abc", "def"), + createModel( + "url", + "api_key", + null, + "voyage-3-large" + ) + ); + + InferenceServiceResults parsedResults = VoyageAIEmbeddingsResponseEntity.fromResponse( + request, + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat( + ((InferenceTextEmbeddingFloatResults)parsedResults).embeddings(), + is(List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 1.0F }))) + ); + } + + public void testFromResponse_SucceedsWhenEmbeddingValueIsLong() throws IOException { + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 40294967295 + ] + } + ], + "model": "voyage-embeddings-v3", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + + VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest( + List.of("abc", "def"), + createModel( + "url", + "api_key", + null, + "voyage-3-large" + ) + ); + + InferenceServiceResults parsedResults = VoyageAIEmbeddingsResponseEntity.fromResponse( + request, + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat( + ((InferenceTextEmbeddingFloatResults)parsedResults).embeddings(), + is(List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 4.0294965E10F }))) + ); + } + + public void testFromResponse_FailsWhenEmbeddingValueIsAnObject() { + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + {} + ] + } + ], + "model": "voyage-embeddings-v3", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + + VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest( + List.of("abc", "def"), + createModel( + "url", + "api_key", + null, + "voyage-3-large" + ) + ); + + var thrownException = expectThrows( + ParsingException.class, + () -> VoyageAIEmbeddingsResponseEntity.fromResponse( + request, + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + + assertThat( + thrownException.getMessage(), + is("Failed to parse object: expecting token of type [VALUE_NUMBER] but found [START_OBJECT]") + ); + } + + public void testFieldsInDifferentOrderServer() throws IOException { + // The fields of the objects in the data array are reordered + String response = """ + { + "object": "list", + "id": "6667830b-716b-4796-9a61-33b67b5cc81d", + "model": "voyage-3-large", + "data": [ + { + "embedding": [ + -0.9, + 0.5, + 0.3 + ], + "index": 0, + "object": "embedding" + }, + { + "index": 0, + "embedding": [ + 0.1, + 0.5 + ], + "object": "embedding" + }, + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.5, + 0.5 + ] + } + ], + "usage": { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0 + } + }"""; + + VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest( + List.of("abc", "def"), + createModel( + "url", + "api_key", + null, + "voyage-3-large" + ) + ); + + InferenceServiceResults parsedResults = VoyageAIEmbeddingsResponseEntity.fromResponse( + request, + new HttpResult(mock(HttpResponse.class), response.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat( + parsedResults, + instanceOf(InferenceTextEmbeddingFloatResults.class) + ); + + assertThat( + ((InferenceTextEmbeddingFloatResults)parsedResults).embeddings(), + is( + List.of( + new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { -0.9F, 0.5F, 0.3F }), + new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.1F, 0.5F }), + new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.5F, 0.5F }) + ) + ) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIErrorResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIErrorResponseEntityTests.java new file mode 100644 index 0000000000000..af78c81c3c10c --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIErrorResponseEntityTests.java @@ -0,0 +1,52 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.voyageai; + +import org.apache.http.HttpResponse; +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; +import org.elasticsearch.xpack.inference.external.response.voyageai.VoyageAIErrorResponseEntity; +import org.hamcrest.MatcherAssert; + +import java.nio.charset.StandardCharsets; + +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class VoyageAIErrorResponseEntityTests extends ESTestCase { + public void testFromResponse() { + String message = "\"input\" length 2049 is larger than the largest allowed size 2048"; + String escapedMessage = message.replace("\\", "\\\\").replace("\"", "\\\""); + String responseJson = Strings.format(""" + { + "detail": "%s" + } + """, escapedMessage); + + ErrorResponse errorResponse = VoyageAIErrorResponseEntity.fromResponse( + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + assertNotNull(errorResponse); + MatcherAssert.assertThat(errorResponse.getErrorMessage(), is(message)); + } + + public void testFromResponse_noMessage() { + String responseJson = """ + { + "error": "abc" + } + """; + + ErrorResponse errorResponse = VoyageAIErrorResponseEntity.fromResponse( + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + MatcherAssert.assertThat(errorResponse, is(ErrorResponse.UNDEFINED_ERROR)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIRerankResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIRerankResponseEntityTests.java new file mode 100644 index 0000000000000..4b6bd78b4e66c --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIRerankResponseEntityTests.java @@ -0,0 +1,175 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.voyageai; + +import org.apache.http.HttpResponse; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.response.voyageai.VoyageAIRerankResponseEntity; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; + +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class VoyageAIRerankResponseEntityTests extends ESTestCase { + + public void testResponseLiteral() throws IOException { + InferenceServiceResults parsedResults = VoyageAIRerankResponseEntity.fromResponse( + new HttpResult(mock(HttpResponse.class), responseLiteral.getBytes(StandardCharsets.UTF_8)) + ); + + MatcherAssert.assertThat(parsedResults, instanceOf(RankedDocsResults.class)); + List expected = responseLiteralDocs(); + for (int i = 0; i < ((RankedDocsResults) parsedResults).getRankedDocs().size(); i++) { + assertEquals(((RankedDocsResults) parsedResults).getRankedDocs().get(i).index(), expected.get(i).index()); + } + } + + public void testGeneratedResponse() throws IOException { + int numDocs = randomIntBetween(1, 10); + + List expected = new ArrayList<>(numDocs); + StringBuilder responseBuilder = new StringBuilder(); + + responseBuilder.append("{"); + responseBuilder.append("\"model\": \"model\","); + responseBuilder.append("\"index\":\"").append(randomAlphaOfLength(36)).append("\","); + responseBuilder.append("\"data\": ["); + List indices = linear(numDocs); + List scores = linearFloats(numDocs); + for (int i = 0; i < numDocs; i++) { + int index = indices.remove(randomInt(indices.size() - 1)); + + responseBuilder.append("{"); + responseBuilder.append("\"index\":").append(index).append(","); + responseBuilder.append("\"relevance_score\":").append(scores.get(i).toString()).append("}"); + expected.add(new RankedDocsResults.RankedDoc(index, scores.get(i), null)); + if (i < numDocs - 1) { + responseBuilder.append(","); + } + } + responseBuilder.append("],"); + responseBuilder.append("\"usage\": {"); + responseBuilder.append("\"total_tokens\": 15}"); + responseBuilder.append("}"); + + InferenceServiceResults parsedResults = VoyageAIRerankResponseEntity.fromResponse( + new HttpResult(mock(HttpResponse.class), responseBuilder.toString().getBytes(StandardCharsets.UTF_8)) + ); + MatcherAssert.assertThat(parsedResults, instanceOf(RankedDocsResults.class)); + for (int i = 0; i < ((RankedDocsResults) parsedResults).getRankedDocs().size(); i++) { + assertEquals(((RankedDocsResults) parsedResults).getRankedDocs().get(i).index(), expected.get(i).index()); + } + } + + private ArrayList responseLiteralDocs() { + var list = new ArrayList(); + + list.add(new RankedDocsResults.RankedDoc(2, 0.98005307F, null)); + list.add(new RankedDocsResults.RankedDoc(3, 0.27904198F, null)); + list.add(new RankedDocsResults.RankedDoc(0, 0.10194652F, null)); + return list; + + }; + + private final String responseLiteral = """ + { + "model": "model", + "data": [ + { + "index": 2, + "relevance_score": 0.98005307 + }, + { + "index": 3, + "relevance_score": 0.27904198 + }, + { + "index": 0, + "relevance_score": 0.10194652 + } + ], + "usage": { + "total_tokens": 15 + } + } + """; + + public void testResponseLiteralWithDocuments() throws IOException { + InferenceServiceResults parsedResults = VoyageAIRerankResponseEntity.fromResponse( + new HttpResult(mock(HttpResponse.class), responseLiteralWithDocuments.getBytes(StandardCharsets.UTF_8)) + ); + + MatcherAssert.assertThat(parsedResults, instanceOf(RankedDocsResults.class)); + MatcherAssert.assertThat(((RankedDocsResults) parsedResults).getRankedDocs(), is(responseLiteralDocsWithText)); + } + + private final String responseLiteralWithDocuments = """ + { + "model": "model", + "data": [ + { + "document": "Washington, D.C..", + "index": 2, + "relevance_score": 0.98005307 + }, + { + "document": "Capital punishment has existed in the United States since beforethe United States was a country. ", + "index": 3, + "relevance_score": 0.27904198 + }, + { + "document": "Carson City is the capital city of the American state of Nevada.", + "index": 0, + "relevance_score": 0.10194652 + } + ], + "usage": { + "total_tokens": 15 + } + } + """; + + private final List responseLiteralDocsWithText = List.of( + new RankedDocsResults.RankedDoc(2, 0.98005307F, "Washington, D.C.."), + new RankedDocsResults.RankedDoc( + 3, + 0.27904198F, + "Capital punishment has existed in the United States since beforethe United States was a country. " + ), + new RankedDocsResults.RankedDoc(0, 0.10194652F, "Carson City is the capital city of the American state of Nevada.") + ); + + private ArrayList linear(int n) { + ArrayList list = new ArrayList<>(); + for (int i = 0; i <= n; i++) { + list.add(i); + } + return list; + } + + // creates a list of doubles of monotonically decreasing magnitude + private ArrayList linearFloats(int n) { + ArrayList list = new ArrayList<>(); + float startValue = 1.0f; + float decrement = startValue / n + 1; + for (int i = 0; i <= n; i++) { + list.add(startValue - (i * decrement)); + } + return list; + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/voyageai/VoyageAIResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/voyageai/VoyageAIResponseHandlerTests.java new file mode 100644 index 0000000000000..0c45fa1b18429 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/voyageai/VoyageAIResponseHandlerTests.java @@ -0,0 +1,139 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.voyageai; + +import org.apache.http.Header; +import org.apache.http.HeaderElement; +import org.apache.http.HttpResponse; +import org.apache.http.StatusLine; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIResponseHandler; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.hamcrest.MatcherAssert; + +import java.nio.charset.StandardCharsets; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.core.Is.is; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class VoyageAIResponseHandlerTests extends ESTestCase { + public void testCheckForFailureStatusCode_DoesNotThrowForStatusCodesBetween200And299() { + callCheckForFailureStatusCode(randomIntBetween(200, 299), "id"); + } + + public void testCheckForFailureStatusCode_ThrowsFor503() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(503, "id")); + assertFalse(exception.shouldRetry()); + MatcherAssert.assertThat( + exception.getCause().getMessage(), + containsString("Received a server error status code for request from inference entity id [id] status [503]") + ); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST)); + } + + public void testCheckForFailureStatusCode_ThrowsFor500_WithShouldRetryTrue() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(500, "id")); + assertTrue(exception.shouldRetry()); + MatcherAssert.assertThat( + exception.getCause().getMessage(), + containsString("Received a server error status code for request from inference entity id [id] status [500]") + ); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST)); + } + + public void testCheckForFailureStatusCode_ThrowsFor429_WithShouldRetryTrue() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(429, "id")); + assertTrue(exception.shouldRetry()); + MatcherAssert.assertThat( + exception.getCause().getMessage(), + containsString("Received a rate limit status code for request from inference entity id [id] status [429]") + ); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.TOO_MANY_REQUESTS)); + } + + public void testCheckForFailureStatusCode_ThrowsFor400() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(400, "id")); + assertFalse(exception.shouldRetry()); + MatcherAssert.assertThat( + exception.getCause().getMessage(), + containsString("Received an input validation error response for request from inference entity id [id] status [400]") + ); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST)); + } + + public void testCheckForFailureStatusCode_ThrowsFor400_InputsTooLarge() { + var exception = expectThrows( + RetryException.class, + () -> callCheckForFailureStatusCode(400, "\"input\" length 2049 is larger than the largest allowed size 2048", "id") + ); + assertFalse(exception.shouldRetry()); + MatcherAssert.assertThat( + exception.getCause().getMessage(), + containsString("Received an input validation error response for request from inference entity id [id] status [400]") + ); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST)); + } + + public void testCheckForFailureStatusCode_ThrowsFor401() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(401, "inferenceEntityId")); + assertFalse(exception.shouldRetry()); + MatcherAssert.assertThat( + exception.getCause().getMessage(), + containsString( + "Received an authentication error status code for request from inference entity id [inferenceEntityId] status [401]" + ) + ); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.UNAUTHORIZED)); + } + + public void testCheckForFailureStatusCode_ThrowsFor402() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(402, "inferenceEntityId")); + assertFalse(exception.shouldRetry()); + MatcherAssert.assertThat(exception.getCause().getMessage(), containsString("Payment required")); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.PAYMENT_REQUIRED)); + } + + private static void callCheckForFailureStatusCode(int statusCode, String modelId) { + callCheckForFailureStatusCode(statusCode, null, modelId); + } + + private static void callCheckForFailureStatusCode(int statusCode, @Nullable String errorMessage, String modelId) { + var statusLine = mock(StatusLine.class); + when(statusLine.getStatusCode()).thenReturn(statusCode); + + var httpResponse = mock(HttpResponse.class); + when(httpResponse.getStatusLine()).thenReturn(statusLine); + var header = mock(Header.class); + when(header.getElements()).thenReturn(new HeaderElement[] {}); + when(httpResponse.getFirstHeader(anyString())).thenReturn(header); + + String escapedErrorMessage = errorMessage != null ? errorMessage.replace("\\", "\\\\").replace("\"", "\\\"") : errorMessage; + + String responseJson = Strings.format(""" + { + "detail": "%s" + } + """, escapedErrorMessage); + + var mockRequest = mock(Request.class); + when(mockRequest.getInferenceEntityId()).thenReturn(modelId); + var httpResult = new HttpResult(httpResponse, errorMessage == null ? new byte[] {} : responseJson.getBytes(StandardCharsets.UTF_8)); + var handler = new VoyageAIResponseHandler("", (request, result) -> null); + + handler.checkForFailureStatusCode(mockRequest, httpResult); + } +} From 3f1a75a86199cc80dd5ac4705c21d2947831791f Mon Sep 17 00:00:00 2001 From: fzowl Date: Sun, 9 Feb 2025 13:58:28 +0100 Subject: [PATCH 11/20] More tests/corrections --- .../voyageai/VoyageAIActionCreatorTests.java | 192 +++------------ .../VoyageAIEmbeddingsActionTests.java | 221 +++++++++++------- .../results/TextEmbeddingResultsTests.java | 6 + .../voyageai/VoyageAIServiceTests.java | 14 +- .../VoyageAIEmbeddingsModelTests.java | 25 ++ 5 files changed, 199 insertions(+), 259 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreatorTests.java index 1061c61ad9dc2..2d2aaa0014eac 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreatorTests.java @@ -19,18 +19,15 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.inference.external.action.cohere.CohereActionCreator; +import org.elasticsearch.xpack.inference.external.action.voyageai.VoyageAIActionCreator; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; -import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; -import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation; -import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModelTests; -import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; -import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModelTests; -import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; -import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingType; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettingsTests; import org.hamcrest.MatcherAssert; import org.junit.After; import org.junit.Before; @@ -45,7 +42,6 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; -import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.hamcrest.Matchers.is; @@ -73,7 +69,7 @@ public void shutdown() throws IOException { webServer.close(); } - public void testCreate_CohereEmbeddingsModel() throws IOException { + public void testCreate_VoyageAIEmbeddingsModel() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { @@ -81,42 +77,34 @@ public void testCreate_CohereEmbeddingsModel() throws IOException { String responseJson = """ { - "id": "de37399c-5df6-47cb-bc57-e3c5680c977b", - "texts": [ - "hello" - ], - "embeddings": { - "float": [ - [ + "object": "list", + "data": [{ + "object": "embedding", + "embedding": [ 0.123, -0.123 - ] - ] - }, - "meta": { - "api_version": { - "version": "1" - }, - "billed_units": { - "input_tokens": 1 - } - }, - "response_type": "embeddings_by_type" + ], + "index": 0 + }], + "model": "voyage-3-large", + "usage": { + "total_tokens": 123 + } } """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = CohereEmbeddingsModelTests.createModel( + var model = VoyageAIEmbeddingsModelTests.createModel( getUrl(webServer), "secret", - new CohereEmbeddingsTaskSettings(InputType.INGEST, CohereTruncation.START), + new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, true), 1024, 1024, "model", - CohereEmbeddingType.FLOAT + VoyageAIEmbeddingType.FLOAT ); - var actionCreator = new CohereActionCreator(sender, createWithEmptySettings(threadPool)); - var overriddenTaskSettings = CohereEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH, CohereTruncation.END); + var actionCreator = new VoyageAIActionCreator(sender, createWithEmptySettings(threadPool)); + var overriddenTaskSettings = VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH); var action = actionCreator.create(model, overriddenTaskSettings, InputType.UNSPECIFIED); PlainActionFuture listener = new PlainActionFuture<>(); @@ -138,139 +126,15 @@ public void testCreate_CohereEmbeddingsModel() throws IOException { requestMap, is( Map.of( - "texts", - List.of("abc"), - "model", - "model", - "input_type", - "search_query", - "embedding_types", - List.of("float"), - "truncate", - "end" + "output_dtype","float", + "truncation", true, + "input_type", "query", + "output_dimension",1024, + "input", List.of("abc"), + "model", "model" ) ) ); } } - - public void testCreate_CohereCompletionModel_WithModelSpecified() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var sender = createSender(senderFactory)) { - sender.start(); - - String responseJson = """ - { - "response_id": "some id", - "text": "result", - "generation_id": "some id", - "chat_history": [ - { - "role": "USER", - "message": "input" - }, - { - "role": "CHATBOT", - "message": "result" - } - ], - "finish_reason": "COMPLETE", - "meta": { - "api_version": { - "version": "1" - }, - "billed_units": { - "input_tokens": 4, - "output_tokens": 191 - }, - "tokens": { - "input_tokens": 70, - "output_tokens": 191 - } - } - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = CohereCompletionModelTests.createModel(getUrl(webServer), "secret", "model"); - var actionCreator = new CohereActionCreator(sender, createWithEmptySettings(threadPool)); - var action = actionCreator.create(model, Map.of()); - - PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); - - var result = listener.actionGet(TIMEOUT); - - assertThat(result.asMap(), is(buildExpectationCompletion(List.of("result")))); - assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().get(0).getUri().getQuery()); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), is(XContentType.JSON.mediaType())); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), is("Bearer secret")); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap, is(Map.of("message", "abc", "model", "model"))); - } - } - - public void testCreate_CohereCompletionModel_WithoutModelSpecified() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var sender = createSender(senderFactory)) { - sender.start(); - - String responseJson = """ - { - "response_id": "some id", - "text": "result", - "generation_id": "some id", - "chat_history": [ - { - "role": "USER", - "message": "input" - }, - { - "role": "CHATBOT", - "message": "result" - } - ], - "finish_reason": "COMPLETE", - "meta": { - "api_version": { - "version": "1" - }, - "billed_units": { - "input_tokens": 4, - "output_tokens": 191 - }, - "tokens": { - "input_tokens": 70, - "output_tokens": 191 - } - } - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = CohereCompletionModelTests.createModel(getUrl(webServer), "secret", null); - var actionCreator = new CohereActionCreator(sender, createWithEmptySettings(threadPool)); - var action = actionCreator.create(model, Map.of()); - - PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); - - var result = listener.actionGet(TIMEOUT); - - assertThat(result.asMap(), is(buildExpectationCompletion(List.of("result")))); - assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().get(0).getUri().getQuery()); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), is(XContentType.JSON.mediaType())); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), is("Bearer secret")); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap, is(Map.of("message", "abc"))); - } - } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIEmbeddingsActionTests.java index c560e3755ad3c..bcf08e7b735b2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIEmbeddingsActionTests.java @@ -26,16 +26,15 @@ import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.HttpResult; -import org.elasticsearch.xpack.inference.external.http.sender.CohereEmbeddingsRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.VoyageAIEmbeddingsRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; -import org.elasticsearch.xpack.inference.external.request.cohere.CohereUtils; +import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIUtils; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; -import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation; -import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; -import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModelTests; -import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingType; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings; import org.hamcrest.MatcherAssert; import org.junit.After; import org.junit.Before; @@ -51,6 +50,7 @@ import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationBinary; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationByte; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.hamcrest.Matchers.containsString; @@ -90,27 +90,19 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { String responseJson = """ { - "id": "de37399c-5df6-47cb-bc57-e3c5680c977b", - "texts": [ - "hello" - ], - "embeddings": { - "float": [ - [ + "object": "list", + "data": [{ + "object": "embedding", + "embedding": [ 0.123, -0.123 - ] - ] - }, - "meta": { - "api_version": { - "version": "1" - }, - "billed_units": { - "input_tokens": 1 - } - }, - "response_type": "embeddings_by_type" + ], + "index": 0 + }], + "model": "voyage-3-large", + "usage": { + "total_tokens": 123 + } } """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); @@ -118,9 +110,9 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { var action = createAction( getUrl(webServer), "secret", - new CohereEmbeddingsTaskSettings(InputType.INGEST, CohereTruncation.START), + new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, true), "model", - CohereEmbeddingType.FLOAT, + VoyageAIEmbeddingType.FLOAT, sender ); @@ -138,25 +130,21 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { ); MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); MatcherAssert.assertThat( - webServer.requests().get(0).getHeader(CohereUtils.REQUEST_SOURCE_HEADER), - equalTo(CohereUtils.ELASTIC_REQUEST_SOURCE) + webServer.requests().get(0).getHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER), + equalTo(VoyageAIUtils.ELASTIC_REQUEST_SOURCE) ); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); MatcherAssert.assertThat( requestMap, - is( + equalTo( Map.of( - "texts", - List.of("abc"), - "model", - "model", - "input_type", - "search_document", - "embedding_types", - List.of("float"), - "truncate", - "start" + "input", List.of("abc"), + "model", "model", + "input_type", "document", + "output_dtype", "float", + "truncation", true, + "output_dimension", 1024 ) ) ); @@ -171,27 +159,19 @@ public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws I String responseJson = """ { - "id": "de37399c-5df6-47cb-bc57-e3c5680c977b", - "texts": [ - "hello" - ], - "embeddings": { - "int8": [ - [ + "object": "list", + "data": [{ + "object": "embedding", + "embedding": [ 0, -1 - ] - ] - }, - "meta": { - "api_version": { - "version": "1" - }, - "billed_units": { - "input_tokens": 1 - } - }, - "response_type": "embeddings_by_type" + ], + "index": 0 + }], + "model": "voyage-3-large", + "usage": { + "total_tokens": 123 + } } """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); @@ -199,9 +179,9 @@ public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws I var action = createAction( getUrl(webServer), "secret", - new CohereEmbeddingsTaskSettings(InputType.INGEST, CohereTruncation.START), + new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, true), "model", - CohereEmbeddingType.INT8, + VoyageAIEmbeddingType.INT8, sender ); @@ -219,8 +199,77 @@ public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws I ); MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); MatcherAssert.assertThat( - webServer.requests().get(0).getHeader(CohereUtils.REQUEST_SOURCE_HEADER), - equalTo(CohereUtils.ELASTIC_REQUEST_SOURCE) + webServer.requests().get(0).getHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER), + equalTo(VoyageAIUtils.ELASTIC_REQUEST_SOURCE) + ); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat( + requestMap, + is( + Map.of( + "input", List.of("abc"), + "model", "model", + "input_type", "document", + "output_dtype", "int8", + "truncation", true, + "output_dimension", 1024 + ) + ) + ); + } + } + + public void testExecute_ReturnsSuccessfulResponse_ForBinaryResponseType() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = HttpRequestSenderTests.createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "object": "list", + "data": [{ + "object": "embedding", + "embedding": [ + 0, + -1 + ], + "index": 0 + }], + "model": "voyage-3-large", + "usage": { + "total_tokens": 123 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var action = createAction( + getUrl(webServer), + "secret", + new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, true), + "model", + VoyageAIEmbeddingType.BINARY, + sender + ); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + assertEquals(buildExpectationBinary(List.of(new byte[] { 0, -1 })), result.asMap()); + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER), + equalTo(VoyageAIUtils.ELASTIC_REQUEST_SOURCE) ); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); @@ -228,16 +277,12 @@ public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws I requestMap, is( Map.of( - "texts", - List.of("abc"), - "model", - "model", - "input_type", - "search_document", - "embedding_types", - List.of("int8"), - "truncate", - "start" + "input", List.of("abc"), + "model", "model", + "input_type", "document", + "output_dtype", "binary", + "truncation", true, + "output_dimension", 1024 ) ) ); @@ -248,7 +293,7 @@ public void testExecute_ThrowsURISyntaxException_ForInvalidUrl() throws IOExcept try (var sender = mock(Sender.class)) { var thrownException = expectThrows( IllegalArgumentException.class, - () -> createAction("^^", "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender) + () -> createAction("^^", "secret", VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender) ); MatcherAssert.assertThat(thrownException.getMessage(), containsString("unable to parse url [^^]")); } @@ -258,7 +303,7 @@ public void testExecute_ThrowsElasticsearchException() { var sender = mock(Sender.class); doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any(), any(), any()); - var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); + var action = createAction(getUrl(webServer), "secret", VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "model", null, sender); PlainActionFuture listener = new PlainActionFuture<>(); action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); @@ -279,7 +324,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled return Void.TYPE; }).when(sender).send(any(), any(), any(), any()); - var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); + var action = createAction(getUrl(webServer), "secret", VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "model", null, sender); PlainActionFuture listener = new PlainActionFuture<>(); action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); @@ -288,7 +333,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled MatcherAssert.assertThat( thrownException.getMessage(), - is(format("Failed to send Cohere embeddings request to [%s]", getUrl(webServer))) + is(format("Failed to send VoyageAI embeddings request to [%s]", getUrl(webServer))) ); } @@ -303,21 +348,21 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled return Void.TYPE; }).when(sender).send(any(), any(), any(), any()); - var action = createAction(null, "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); + var action = createAction(null, "secret", VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "model", null, sender); PlainActionFuture listener = new PlainActionFuture<>(); action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); - MatcherAssert.assertThat(thrownException.getMessage(), is("Failed to send Cohere embeddings request")); + MatcherAssert.assertThat(thrownException.getMessage(), is("Failed to send VoyageAI embeddings request")); } public void testExecute_ThrowsException() { var sender = mock(Sender.class); doThrow(new IllegalArgumentException("failed")).when(sender).send(any(), any(), any(), any()); - var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); + var action = createAction(getUrl(webServer), "secret", VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "model", null, sender); PlainActionFuture listener = new PlainActionFuture<>(); action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); @@ -326,7 +371,7 @@ public void testExecute_ThrowsException() { MatcherAssert.assertThat( thrownException.getMessage(), - is(format("Failed to send Cohere embeddings request to [%s]", getUrl(webServer))) + is(format("Failed to send VoyageAI embeddings request to [%s]", getUrl(webServer))) ); } @@ -334,30 +379,30 @@ public void testExecute_ThrowsExceptionWithNullUrl() { var sender = mock(Sender.class); doThrow(new IllegalArgumentException("failed")).when(sender).send(any(), any(), any(), any()); - var action = createAction(null, "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); + var action = createAction(null, "secret", VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "model", null, sender); PlainActionFuture listener = new PlainActionFuture<>(); action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); - MatcherAssert.assertThat(thrownException.getMessage(), is("Failed to send Cohere embeddings request")); + MatcherAssert.assertThat(thrownException.getMessage(), is("Failed to send VoyageAI embeddings request")); } private ExecutableAction createAction( String url, String apiKey, - CohereEmbeddingsTaskSettings taskSettings, + VoyageAIEmbeddingsTaskSettings taskSettings, @Nullable String modelName, - @Nullable CohereEmbeddingType embeddingType, + @Nullable VoyageAIEmbeddingType embeddingType, Sender sender ) { - var model = CohereEmbeddingsModelTests.createModel(url, apiKey, taskSettings, 1024, 1024, modelName, embeddingType); + var model = VoyageAIEmbeddingsModelTests.createModel(url, apiKey, taskSettings, 1024, 1024, modelName, embeddingType); var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage( model.getServiceSettings().getCommonSettings().uri(), - "Cohere embeddings" + "VoyageAI embeddings" ); - var requestCreator = CohereEmbeddingsRequestManager.of(model, threadPool); + var requestCreator = VoyageAIEmbeddingsRequestManager.of(model, threadPool); return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java index 56bd690a9cdbf..b6b0bbefdf0a5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java @@ -146,4 +146,10 @@ public static Map buildExpectationByte(List embeddings) ); } + public static Map buildExpectationBinary(List embeddings) { + return Map.of( + "text_embedding_bits", + embeddings.stream().map(InferenceByteEmbedding::new).toList() + ); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java index fee900220f00b..73bc3d9c32070 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java @@ -888,7 +888,7 @@ public void testCheckModelConfig_UpdatesSimilarityToDotProduct_WhenItIsNull() th 10, 1, "voyage-clip-v2", - null + (SimilarityMeasure) null ); PlainActionFuture listener = new PlainActionFuture<>(); service.checkModelConfig(model, listener); @@ -1022,7 +1022,7 @@ public void testInfer_Embedding_UnauthorisedResponse() throws IOException { 1024, 1024, "model", - null + (SimilarityMeasure) null ); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( @@ -1109,7 +1109,7 @@ public void testInfer_Embedding_Get_Response_Ingest() throws IOException { 1024, 1024, "voyage-clip-v2", - null + (SimilarityMeasure) null ); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( @@ -1183,7 +1183,7 @@ public void testInfer_Embedding_Get_Response_Search() throws IOException { 1024, 1024, "voyage-clip-v2", - null + (SimilarityMeasure) null ); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( @@ -1238,7 +1238,7 @@ public void testInfer_Embedding_Get_Response_clustering() throws IOException { 1024, 1024, "voyage-clip-v2", - null + (SimilarityMeasure) null ); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( @@ -1308,7 +1308,7 @@ public void testInfer_Embedding_Get_Response_NullInputType() throws IOException 1024, 1024, "voyage-clip-v2", - null + (SimilarityMeasure) null ); PlainActionFuture listener = new PlainActionFuture<>(); service.infer(model, null, List.of("abc"), false, new HashMap<>(), null, InferenceAction.Request.DEFAULT_TIMEOUT, listener); @@ -1715,7 +1715,7 @@ public void testInfer_Embedding_DoesNotSetInputType_WhenNotPresentInTaskSettings 1024, 1024, "voyage-clip-v2", - null + (SimilarityMeasure) null ); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsModelTests.java index fdee56bf1a9c5..32a03a26f0323 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsModelTests.java @@ -149,6 +149,31 @@ public static VoyageAIEmbeddingsModel createModel( ); } + public static VoyageAIEmbeddingsModel createModel( + String url, + String apiKey, + VoyageAIEmbeddingsTaskSettings taskSettings, + @Nullable Integer tokenLimit, + @Nullable Integer dimensions, + String model, + VoyageAIEmbeddingType embeddingType + ) { + return new VoyageAIEmbeddingsModel( + "id", + "service", + new VoyageAIEmbeddingsServiceSettings( + new VoyageAIServiceSettings(url, model, null), + embeddingType, + SimilarityMeasure.DOT_PRODUCT, + dimensions, + tokenLimit + ), + taskSettings, + null, + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + public static VoyageAIEmbeddingsModel createModel( String url, String apiKey, From a89fdf1d484b2d49eb3b9857133ee9b3fc30fea3 Mon Sep 17 00:00:00 2001 From: fzowl Date: Sun, 9 Feb 2025 14:14:45 +0100 Subject: [PATCH 12/20] Removing warnings --- .../VoyageAIRerankResponseEntity.java | 1 - .../VoyageAIEmbeddingsTaskSettings.java | 1 - .../voyageai/VoyageAIActionCreatorTests.java | 9 +- .../VoyageAIEmbeddingsActionTests.java | 30 +++--- .../VoyageAIEmbeddingsRequestEntityTests.java | 1 - .../VoyageAIEmbeddingsRequestTests.java | 2 - .../voyageai/VoyageAIRequestTests.java | 2 - .../VoyageAIRerankRequestEntityTests.java | 1 - .../voyageai/VoyageAIRerankRequestTests.java | 1 - .../request/voyageai/VoyageAIUtilsTests.java | 1 - ...VoyageAIEmbeddingsResponseEntityTests.java | 1 - .../VoyageAIErrorResponseEntityTests.java | 1 - .../VoyageAIRerankResponseEntityTests.java | 98 +++++++++---------- .../voyageai/VoyageAIServiceTests.java | 72 +++++++------- ...oyageAIEmbeddingsServiceSettingsTests.java | 3 +- .../VoyageAIRerankTaskSettingsTests.java | 5 +- 16 files changed, 107 insertions(+), 122 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIRerankResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIRerankResponseEntity.java index 2d7b5d9846342..2a1c1d868b248 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIRerankResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIRerankResponseEntity.java @@ -35,7 +35,6 @@ public class VoyageAIRerankResponseEntity { /** * Parses the VoyageAI ranked response. - * * For a request like: * "model": "rerank-2", * "query": "What is the capital of the United States?", diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsTaskSettings.java index d9e6076d53bcb..5d8d282588349 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsTaskSettings.java @@ -71,7 +71,6 @@ public static VoyageAIEmbeddingsTaskSettings fromMap(Map map) { * Creates a new {@link VoyageAIEmbeddingsTaskSettings} by preferring non-null fields from the provided parameters. * For the input type, preference is given to requestInputType if it is not null and not UNSPECIFIED. * Then preference is given to the requestTaskSettings and finally to originalSettings even if the value is null. - * * Similarly, for the truncation field preference is given to requestTaskSettings if it is not null and then to * originalSettings. * @param originalSettings the settings stored as part of the inference entity configuration diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreatorTests.java index 2d2aaa0014eac..faeb8d477f6f4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreatorTests.java @@ -19,7 +19,6 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.inference.external.action.voyageai.VoyageAIActionCreator; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -114,14 +113,14 @@ public void testCreate_VoyageAIEmbeddingsModel() throws IOException { MatcherAssert.assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })))); MatcherAssert.assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().get(0).getUri().getQuery()); + assertNull(webServer.requests().getFirst().getUri().getQuery()); MatcherAssert.assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()) ); - MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); MatcherAssert.assertThat( requestMap, is( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIEmbeddingsActionTests.java index bcf08e7b735b2..50237bde23a37 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIEmbeddingsActionTests.java @@ -123,18 +123,18 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { MatcherAssert.assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })))); MatcherAssert.assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().get(0).getUri().getQuery()); + assertNull(webServer.requests().getFirst().getUri().getQuery()); MatcherAssert.assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()) ); - MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); MatcherAssert.assertThat( - webServer.requests().get(0).getHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER), + webServer.requests().getFirst().getHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER), equalTo(VoyageAIUtils.ELASTIC_REQUEST_SOURCE) ); - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); MatcherAssert.assertThat( requestMap, equalTo( @@ -192,18 +192,18 @@ public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws I assertEquals(buildExpectationByte(List.of(new byte[] { 0, -1 })), result.asMap()); MatcherAssert.assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().get(0).getUri().getQuery()); + assertNull(webServer.requests().getFirst().getUri().getQuery()); MatcherAssert.assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()) ); - MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); MatcherAssert.assertThat( - webServer.requests().get(0).getHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER), + webServer.requests().getFirst().getHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER), equalTo(VoyageAIUtils.ELASTIC_REQUEST_SOURCE) ); - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); MatcherAssert.assertThat( requestMap, is( @@ -261,18 +261,18 @@ public void testExecute_ReturnsSuccessfulResponse_ForBinaryResponseType() throws assertEquals(buildExpectationBinary(List.of(new byte[] { 0, -1 })), result.asMap()); MatcherAssert.assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().get(0).getUri().getQuery()); + assertNull(webServer.requests().getFirst().getUri().getQuery()); MatcherAssert.assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()) ); - MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); MatcherAssert.assertThat( - webServer.requests().get(0).getHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER), + webServer.requests().getFirst().getHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER), equalTo(VoyageAIUtils.ELASTIC_REQUEST_SOURCE) ); - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); MatcherAssert.assertThat( requestMap, is( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntityTests.java index 9685913bd04e6..294fcc559375e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntityTests.java @@ -13,7 +13,6 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIEmbeddingsRequestEntity; import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings; import org.hamcrest.MatcherAssert; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestTests.java index df3901325ff04..fc3cab89b5d18 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestTests.java @@ -12,8 +12,6 @@ import org.elasticsearch.inference.InputType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIEmbeddingsRequest; -import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIUtils; import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel; import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModelTests; import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRequestTests.java index 1ca734ac6e24d..72216d6382968 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRequestTests.java @@ -13,8 +13,6 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIAccount; -import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIRequest; -import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIUtils; import java.net.URI; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntityTests.java index 801717e8d37bd..16a3e56fb1013 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntityTests.java @@ -12,7 +12,6 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIRerankRequestEntity; import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankTaskSettings; import java.io.IOException; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestTests.java index 416b11d39350d..c6aaa933d0303 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestTests.java @@ -12,7 +12,6 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIRerankRequest; import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModelTests; import java.io.IOException; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIUtilsTests.java index 186a2b0410111..4a8271f4fac88 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIUtilsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIUtilsTests.java @@ -8,7 +8,6 @@ package org.elasticsearch.xpack.inference.external.request.voyageai; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIUtils; import static org.hamcrest.Matchers.is; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntityTests.java index 155f3be2dd904..993c0e09145f7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntityTests.java @@ -14,7 +14,6 @@ import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIEmbeddingsRequest; -import org.elasticsearch.xpack.inference.external.response.voyageai.VoyageAIEmbeddingsResponseEntity; import java.io.IOException; import java.nio.charset.StandardCharsets; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIErrorResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIErrorResponseEntityTests.java index af78c81c3c10c..bb5c1da90d776 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIErrorResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIErrorResponseEntityTests.java @@ -12,7 +12,6 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; -import org.elasticsearch.xpack.inference.external.response.voyageai.VoyageAIErrorResponseEntity; import org.hamcrest.MatcherAssert; import java.nio.charset.StandardCharsets; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIRerankResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIRerankResponseEntityTests.java index 4b6bd78b4e66c..1e7a89e5a406e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIRerankResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIRerankResponseEntityTests.java @@ -12,7 +12,6 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; -import org.elasticsearch.xpack.inference.external.response.voyageai.VoyageAIRerankResponseEntity; import org.hamcrest.MatcherAssert; import java.io.IOException; @@ -27,6 +26,28 @@ public class VoyageAIRerankResponseEntityTests extends ESTestCase { public void testResponseLiteral() throws IOException { + String responseLiteral = """ + { + "model": "model", + "data": [ + { + "index": 2, + "relevance_score": 0.98005307 + }, + { + "index": 3, + "relevance_score": 0.27904198 + }, + { + "index": 0, + "relevance_score": 0.10194652 + } + ], + "usage": { + "total_tokens": 15 + } + } + """; InferenceServiceResults parsedResults = VoyageAIRerankResponseEntity.fromResponse( new HttpResult(mock(HttpResponse.class), responseLiteral.getBytes(StandardCharsets.UTF_8)) ); @@ -82,33 +103,34 @@ private ArrayList responseLiteralDocs() { list.add(new RankedDocsResults.RankedDoc(3, 0.27904198F, null)); list.add(new RankedDocsResults.RankedDoc(0, 0.10194652F, null)); return list; + } - }; - - private final String responseLiteral = """ - { - "model": "model", - "data": [ - { - "index": 2, - "relevance_score": 0.98005307 - }, - { - "index": 3, - "relevance_score": 0.27904198 - }, - { - "index": 0, - "relevance_score": 0.10194652 + public void testResponseLiteralWithDocuments() throws IOException { + String responseLiteralWithDocuments = """ + { + "model": "model", + "data": [ + { + "document": "Washington, D.C..", + "index": 2, + "relevance_score": 0.98005307 + }, + { + "document": "Capital punishment has existed in the United States since beforethe United States was a country. ", + "index": 3, + "relevance_score": 0.27904198 + }, + { + "document": "Carson City is the capital city of the American state of Nevada.", + "index": 0, + "relevance_score": 0.10194652 + } + ], + "usage": { + "total_tokens": 15 } - ], - "usage": { - "total_tokens": 15 } - } - """; - - public void testResponseLiteralWithDocuments() throws IOException { + """; InferenceServiceResults parsedResults = VoyageAIRerankResponseEntity.fromResponse( new HttpResult(mock(HttpResponse.class), responseLiteralWithDocuments.getBytes(StandardCharsets.UTF_8)) ); @@ -117,32 +139,6 @@ public void testResponseLiteralWithDocuments() throws IOException { MatcherAssert.assertThat(((RankedDocsResults) parsedResults).getRankedDocs(), is(responseLiteralDocsWithText)); } - private final String responseLiteralWithDocuments = """ - { - "model": "model", - "data": [ - { - "document": "Washington, D.C..", - "index": 2, - "relevance_score": 0.98005307 - }, - { - "document": "Capital punishment has existed in the United States since beforethe United States was a country. ", - "index": 3, - "relevance_score": 0.27904198 - }, - { - "document": "Carson City is the capital city of the American state of Nevada.", - "index": 0, - "relevance_score": 0.10194652 - } - ], - "usage": { - "total_tokens": 15 - } - } - """; - private final List responseLiteralDocsWithText = List.of( new RankedDocsResults.RankedDoc(2, 0.98005307F, "Washington, D.C.."), new RankedDocsResults.RankedDoc( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java index 73bc3d9c32070..0e37c8cddc6e3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java @@ -1128,14 +1128,14 @@ public void testInfer_Embedding_Get_Response_Ingest() throws IOException { assertEquals(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })), result.asMap()); MatcherAssert.assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().get(0).getUri().getQuery()); + assertNull(webServer.requests().getFirst().getUri().getQuery()); MatcherAssert.assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()) ); - MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); MatcherAssert.assertThat( requestMap, is(Map.of( @@ -1202,14 +1202,14 @@ public void testInfer_Embedding_Get_Response_Search() throws IOException { assertEquals(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })), result.asMap()); MatcherAssert.assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().get(0).getUri().getQuery()); + assertNull(webServer.requests().getFirst().getUri().getQuery()); MatcherAssert.assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()) ); - MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); MatcherAssert.assertThat(requestMap, is(Map.of( "input", List.of("abc"), "model", "voyage-clip-v2", @@ -1257,14 +1257,14 @@ public void testInfer_Embedding_Get_Response_clustering() throws IOException { assertEquals(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })), result.asMap()); MatcherAssert.assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().get(0).getUri().getQuery()); + assertNull(webServer.requests().getFirst().getUri().getQuery()); MatcherAssert.assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()) ); - MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); MatcherAssert.assertThat(requestMap, is(Map.of( "input", List.of("abc"), "model", "voyage-clip-v2", @@ -1318,14 +1318,14 @@ public void testInfer_Embedding_Get_Response_NullInputType() throws IOException assertEquals(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })), result.asMap()); MatcherAssert.assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().get(0).getUri().getQuery()); + assertNull(webServer.requests().getFirst().getUri().getQuery()); MatcherAssert.assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()) ); - MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); MatcherAssert.assertThat(requestMap, is(Map.of( "input", List.of("abc"), "model", "voyage-clip-v2", @@ -1394,12 +1394,12 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_NoTopN() throws IOEx MatcherAssert.assertThat(webServer.requests(), hasSize(1)); MatcherAssert.assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()) ); - MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); MatcherAssert.assertThat( requestMap, is( @@ -1480,12 +1480,12 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_TopN() throws IOExce MatcherAssert.assertThat(webServer.requests(), hasSize(1)); MatcherAssert.assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()) ); - MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); MatcherAssert.assertThat( requestMap, is( @@ -1571,12 +1571,12 @@ public void testInfer_Rerank_Get_Response_ReturnDocumentsNull_NoTopN() throws IO ); MatcherAssert.assertThat(webServer.requests(), hasSize(1)); MatcherAssert.assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()) ); - MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); MatcherAssert.assertThat( requestMap, is(Map.of( @@ -1651,12 +1651,12 @@ public void testInfer_Rerank_Get_Response_ReturnDocuments_TopN() throws IOExcept ); MatcherAssert.assertThat(webServer.requests(), hasSize(1)); MatcherAssert.assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()) ); - MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); MatcherAssert.assertThat( requestMap, is( @@ -1733,14 +1733,14 @@ public void testInfer_Embedding_DoesNotSetInputType_WhenNotPresentInTaskSettings MatcherAssert.assertThat(result.asMap(), Matchers.is(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })))); MatcherAssert.assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().get(0).getUri().getQuery()); + assertNull(webServer.requests().getFirst().getUri().getQuery()); MatcherAssert.assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()) ); - MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); MatcherAssert.assertThat(requestMap, is(Map.of( "input", List.of("abc"), "model", "voyage-clip-v2", @@ -1845,14 +1845,14 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(VoyageAIEmbeddingsModel mo } MatcherAssert.assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().get(0).getUri().getQuery()); + assertNull(webServer.requests().getFirst().getUri().getQuery()); MatcherAssert.assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()) ); - MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); MatcherAssert.assertThat(requestMap, is(Map.of( "input", List.of("foo", "bar"), "model", "voyage-clip-v2", diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettingsTests.java index 538b9b77d2069..60eb46f016124 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettingsTests.java @@ -185,7 +185,6 @@ protected NamedWriteableRegistry getNamedWriteableRegistry() { } public static Map getServiceSettingsMap(@Nullable String url, String model) { - var map = new HashMap<>(VoyageAIServiceSettingsTests.getServiceSettingsMap(url, model)); - return map; + return new HashMap<>(VoyageAIServiceSettingsTests.getServiceSettingsMap(url, model)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettingsTests.java index 85cc792c03244..c2c58464719e2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettingsTests.java @@ -65,7 +65,10 @@ public void testFromMap_WithInvalidTopNDocsOnly_ThrowsValidationException() { "invalid" ); var thrownException = expectThrows(ValidationException.class, () -> VoyageAIRerankTaskSettings.fromMap(new HashMap<>(taskMap))); - assertThat(thrownException.getMessage(), containsString("Validation Failed: 1: field [top_k] is not of the expected type. The value [invalid] cannot be converted to a [Integer];")); + assertThat( + thrownException.getMessage(), + containsString("field [top_k] is not of the expected type. The value [invalid] cannot be converted to a [Integer];" + )); } public void testUpdatedTaskSettings_WithEmptyMap_ReturnsSameSettings() { From b7cb871af98478368a84cf218e9da72d253b335f Mon Sep 17 00:00:00 2001 From: fzowl Date: Sun, 9 Feb 2025 14:53:48 +0100 Subject: [PATCH 13/20] Further tests --- .../VoyageAIEmbeddingsServiceSettings.java | 2 +- .../VoyageAIEmbeddingsRequestEntityTests.java | 108 ++++++++++++++++++ .../VoyageAIEmbeddingsRequestTests.java | 72 ++++++++++++ .../VoyageAIRerankRequestEntityTests.java | 42 +++++++ .../voyageai/VoyageAIRerankRequestTests.java | 4 +- .../voyageai/VoyageAIServiceTests.java | 12 +- .../rerank/VoyageAIRerankModelTests.java | 29 +++-- .../VoyageAIRerankTaskSettingsTests.java | 20 +++- 8 files changed, 269 insertions(+), 20 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettings.java index 579c140a7da24..e0ab27e0d5c24 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettings.java @@ -42,7 +42,7 @@ public class VoyageAIEmbeddingsServiceSettings extends FilteredXContentObject im null, null, null, null, null ); - static final String EMBEDDING_TYPE = "embedding_type"; + public static final String EMBEDDING_TYPE = "embedding_type"; public static VoyageAIEmbeddingsServiceSettings fromMap(Map map, ConfigurationParseContext context) { ValidationException validationException = new ValidationException(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntityTests.java index 294fcc559375e..593cc803c4778 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntityTests.java @@ -9,20 +9,128 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings; import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings; import org.hamcrest.MatcherAssert; import java.io.IOException; +import java.util.HashMap; import java.util.List; +import java.util.Map; import static org.hamcrest.CoreMatchers.is; public class VoyageAIEmbeddingsRequestEntityTests extends ESTestCase { + public void testXContent_WritesAllFields_ServiceSettingsDefined() throws IOException { + var entity = new VoyageAIEmbeddingsRequestEntity( + List.of("abc"), + VoyageAIEmbeddingsServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.URL, + "https://www.abc.com", + ServiceFields.SIMILARITY, + SimilarityMeasure.DOT_PRODUCT.toString(), + ServiceFields.DIMENSIONS, + 2048, + ServiceFields.MAX_INPUT_TOKENS, + 512, + VoyageAIServiceSettings.MODEL_ID, + "model", + VoyageAIEmbeddingsServiceSettings.EMBEDDING_TYPE, + "float" + ) + ), + ConfigurationParseContext.PERSISTENT + ), + new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), + "model" + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, is(""" + {"input":["abc"],"model":"model","input_type":"document","output_dimension":2048,"output_dtype":"float"}""")); + } + + public void testXContent_WritesAllFields_ServiceSettingsDefined_Int8() throws IOException { + var entity = new VoyageAIEmbeddingsRequestEntity( + List.of("abc"), + VoyageAIEmbeddingsServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.URL, + "https://www.abc.com", + ServiceFields.SIMILARITY, + SimilarityMeasure.DOT_PRODUCT.toString(), + ServiceFields.DIMENSIONS, + 2048, + ServiceFields.MAX_INPUT_TOKENS, + 512, + VoyageAIServiceSettings.MODEL_ID, + "model", + VoyageAIEmbeddingsServiceSettings.EMBEDDING_TYPE, + "int8" + ) + ), + ConfigurationParseContext.PERSISTENT + ), + new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), + "model" + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, is(""" + {"input":["abc"],"model":"model","input_type":"document","output_dimension":2048,"output_dtype":"int8"}""")); + } + + public void testXContent_WritesAllFields_ServiceSettingsDefined_Binary() throws IOException { + var entity = new VoyageAIEmbeddingsRequestEntity( + List.of("abc"), + VoyageAIEmbeddingsServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.URL, + "https://www.abc.com", + ServiceFields.SIMILARITY, + SimilarityMeasure.DOT_PRODUCT.toString(), + ServiceFields.DIMENSIONS, + 2048, + ServiceFields.MAX_INPUT_TOKENS, + 512, + VoyageAIServiceSettings.MODEL_ID, + "model", + VoyageAIEmbeddingsServiceSettings.EMBEDDING_TYPE, + "binary" + ) + ), + ConfigurationParseContext.PERSISTENT + ), + new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), + "model" + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, is(""" + {"input":["abc"],"model":"model","input_type":"document","output_dimension":2048,"output_dtype":"binary"}""")); + } + public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { var entity = new VoyageAIEmbeddingsRequestEntity( List.of("abc"), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestTests.java index fc3cab89b5d18..1c9abb45a11ee 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.inference.InputType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingType; import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel; import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModelTests; import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings; @@ -87,6 +88,77 @@ public void testCreateRequest_AllOptionsDefined() throws IOException { ))); } + public void testCreateRequest_DimensionDefined() throws IOException { + var request = createRequest( + List.of("abc"), + VoyageAIEmbeddingsModelTests.createModel( + "url", + "secret", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), + null, + 2048, + "model" + ) + ); + + var httpRequest = request.createHttpRequest(); + MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + MatcherAssert.assertThat( + httpPost.getLastHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER).getValue(), + is(VoyageAIUtils.ELASTIC_REQUEST_SOURCE) + ); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + MatcherAssert.assertThat(requestMap, is(Map.of( + "input", List.of("abc"), + "model", "model", + "input_type", "document", + "output_dtype", "float", + "output_dimension", 2048 + ))); + } + + public void testCreateRequest_EmbeddingTypeDefined() throws IOException { + var request = createRequest( + List.of("abc"), + VoyageAIEmbeddingsModelTests.createModel( + "url", + "secret", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), + null, + 2048, + "model", + VoyageAIEmbeddingType.BYTE + ) + ); + + var httpRequest = request.createHttpRequest(); + MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + MatcherAssert.assertThat( + httpPost.getLastHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER).getValue(), + is(VoyageAIUtils.ELASTIC_REQUEST_SOURCE) + ); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + MatcherAssert.assertThat(requestMap, is(Map.of( + "input", List.of("abc"), + "model", "model", + "input_type", "document", + "output_dtype", "int8", + "output_dimension", 2048 + ))); + } + public void testCreateRequest_InputTypeSearch() throws IOException { var request = createRequest( List.of("abc"), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntityTests.java index 16a3e56fb1013..ae431b4b7bb13 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntityTests.java @@ -79,6 +79,48 @@ public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_ReturnDocumen """)); } + public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_TruncationTrue() throws IOException { + var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, false, true), "model"); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "documents": [ + "abc" + ], + "return_documents": false, + "top_k": 8, + "truncation": true + } + """)); + } + + public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_TruncationFalse() throws IOException { + var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, false, false), "model"); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "documents": [ + "abc" + ], + "return_documents": false, + "top_k": 8, + "truncation": false + } + """)); + } + public void testXContent_SingleRequest_DoesNotWriteTopKIfNull() throws IOException { var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), null, "model"); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestTests.java index c6aaa933d0303..0bffd9c7dc268 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestTests.java @@ -102,8 +102,8 @@ public void testTruncate_DoesNotTruncate() { assertThat(truncatedRequest, sameInstance(request)); } - private static VoyageAIRerankRequest createRequest(String query, String input, @Nullable String modelId, @Nullable Integer topN) { - var rerankModel = VoyageAIRerankModelTests.createModel(API_KEY, modelId, topN); + private static VoyageAIRerankRequest createRequest(String query, String input, @Nullable String modelId, @Nullable Integer topK) { + var rerankModel = VoyageAIRerankModelTests.createModel(API_KEY, modelId, topK); return new VoyageAIRerankRequest(query, List.of(input), rerankModel); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java index 0e37c8cddc6e3..d995bff1e9caa 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java @@ -1830,18 +1830,18 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(VoyageAIEmbeddingsModel mo var results = listener.actionGet(TIMEOUT); assertThat(results, hasSize(2)); { - assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class)); - var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(0); + assertThat(results.getFirst(), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var floatResult = (ChunkedInferenceEmbeddingFloat) results.getFirst(); assertThat(floatResult.chunks(), hasSize(1)); - assertEquals("foo", floatResult.chunks().get(0).matchedText()); - assertArrayEquals(new float[] { 0.123f, -0.123f }, floatResult.chunks().get(0).embedding(), 0.0f); + assertEquals("foo", floatResult.chunks().getFirst().matchedText()); + assertArrayEquals(new float[] { 0.123f, -0.123f }, floatResult.chunks().getFirst().embedding(), 0.0f); } { assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class)); var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); - assertEquals("bar", floatResult.chunks().get(0).matchedText()); - assertArrayEquals(new float[] { 0.223f, -0.223f }, floatResult.chunks().get(0).embedding(), 0.0f); + assertEquals("bar", floatResult.chunks().getFirst().matchedText()); + assertArrayEquals(new float[] { 0.223f, -0.223f }, floatResult.chunks().getFirst().embedding(), 0.0f); } MatcherAssert.assertThat(webServer.requests(), hasSize(1)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModelTests.java index 0488e61a43ba3..bca6cc94d9afd 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModelTests.java @@ -14,33 +14,42 @@ import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings; public class VoyageAIRerankModelTests { + public static VoyageAIRerankModel createModel(String apiKey, String modelId, @Nullable Integer topK, @Nullable Boolean truncation) { + return new VoyageAIRerankModel( + "id", + "service", + new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(ESTestCase.randomAlphaOfLength(10), modelId, null)), + new VoyageAIRerankTaskSettings(topK, null, truncation), + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } - public static VoyageAIRerankModel createModel(String apiKey, String modelId, @Nullable Integer topN) { + public static VoyageAIRerankModel createModel(String apiKey, String modelId, @Nullable Integer topK) { return new VoyageAIRerankModel( "id", "service", new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(ESTestCase.randomAlphaOfLength(10), modelId, null)), - new VoyageAIRerankTaskSettings(topN, null, null), + new VoyageAIRerankTaskSettings(topK, null, null), new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); } - public static VoyageAIRerankModel createModel(String modelId, @Nullable Integer topN) { + public static VoyageAIRerankModel createModel(String modelId, @Nullable Integer topK) { return new VoyageAIRerankModel( "id", "service", new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(ESTestCase.randomAlphaOfLength(10), modelId, null)), - new VoyageAIRerankTaskSettings(topN, null, null), + new VoyageAIRerankTaskSettings(topK, null, null), new DefaultSecretSettings(ESTestCase.randomSecureStringOfLength(8)) ); } - public static VoyageAIRerankModel createModel(String modelId, @Nullable Integer topN, Boolean returnDocuments, Boolean truncation) { + public static VoyageAIRerankModel createModel(String modelId, @Nullable Integer topK, Boolean returnDocuments, Boolean truncation) { return new VoyageAIRerankModel( "id", "service", new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(ESTestCase.randomAlphaOfLength(10), modelId, null)), - new VoyageAIRerankTaskSettings(topN, returnDocuments, truncation), + new VoyageAIRerankTaskSettings(topK, returnDocuments, truncation), new DefaultSecretSettings(ESTestCase.randomSecureStringOfLength(8)) ); } @@ -48,7 +57,7 @@ public static VoyageAIRerankModel createModel(String modelId, @Nullable Integer public static VoyageAIRerankModel createModel( String url, String modelId, - @Nullable Integer topN, + @Nullable Integer topK, Boolean returnDocuments, Boolean truncation ) { @@ -56,7 +65,7 @@ public static VoyageAIRerankModel createModel( "id", "service", new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(url, modelId, null)), - new VoyageAIRerankTaskSettings(topN, returnDocuments, truncation), + new VoyageAIRerankTaskSettings(topK, returnDocuments, truncation), new DefaultSecretSettings(ESTestCase.randomSecureStringOfLength(8)) ); } @@ -65,7 +74,7 @@ public static VoyageAIRerankModel createModel( String url, String apiKey, String modelId, - @Nullable Integer topN, + @Nullable Integer topK, Boolean returnDocuments, Boolean truncation ) { @@ -73,7 +82,7 @@ public static VoyageAIRerankModel createModel( "id", "service", new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(url, modelId, null)), - new VoyageAIRerankTaskSettings(topN, returnDocuments, truncation), + new VoyageAIRerankTaskSettings(topK, returnDocuments, truncation), new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettingsTests.java index c2c58464719e2..6823bfb9d21de 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettingsTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.core.Nullable; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceFields; import java.io.IOException; import java.util.HashMap; @@ -28,22 +29,39 @@ public static VoyageAIRerankTaskSettings createRandom() { return new VoyageAIRerankTaskSettings(topNDocsOnly, returnDocuments, truncation); } + public void testFromMap_WithInvalidTruncation_ThrowsValidationException() { + Map taskMap = Map.of( + VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, + true, + VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY, + 5, + VoyageAIServiceFields.TRUNCATION, + "invalid" + ); + var thrownException = expectThrows(ValidationException.class, () -> VoyageAIRerankTaskSettings.fromMap(new HashMap<>(taskMap))); + assertThat(thrownException.getMessage(), containsString("field [truncation] is not of the expected type")); + } + public void testFromMap_WithValidValues_ReturnsSettings() { Map taskMap = Map.of( VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, true, VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY, - 5 + 5, + VoyageAIServiceFields.TRUNCATION, + true ); var settings = VoyageAIRerankTaskSettings.fromMap(new HashMap<>(taskMap)); assertTrue(settings.getReturnDocuments()); assertEquals(5, settings.getTopKDocumentsOnly().intValue()); + assertTrue(settings.getTruncation()); } public void testFromMap_WithNullValues_ReturnsSettingsWithNulls() { var settings = VoyageAIRerankTaskSettings.fromMap(Map.of()); assertNull(settings.getReturnDocuments()); assertNull(settings.getTopKDocumentsOnly()); + assertNull(settings.getTruncation()); } public void testFromMap_WithInvalidReturnDocuments_ThrowsValidationException() { From b9681af5dcf3dc8af8981f5c14374dd9b04bb543 Mon Sep 17 00:00:00 2001 From: fzowl Date: Sun, 9 Feb 2025 15:05:11 +0100 Subject: [PATCH 14/20] Transport version correction --- server/src/main/java/org/elasticsearch/TransportVersions.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index f822f57f78f3d..db4676ec30b68 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -180,7 +180,7 @@ static TransportVersion def(int id) { public static final TransportVersion REMOVE_DESIRED_NODE_VERSION = def(9_004_0_00); - public static final TransportVersion VOYAGE_AI_INTEGRATION_ADDED = def(9_002_0_00); + public static final TransportVersion VOYAGE_AI_INTEGRATION_ADDED = def(9_005_0_00); /* * STOP! READ THIS FIRST! No, really, * ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _ From c2115833bc01a77f14026cc5ad4d0218813b903f Mon Sep 17 00:00:00 2001 From: fzowl Date: Sun, 9 Feb 2025 15:17:57 +0100 Subject: [PATCH 15/20] Adding changelog and correcting TransportVersions --- docs/changelog/122134.yaml | 5 +++++ .../src/main/java/org/elasticsearch/TransportVersions.java | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) create mode 100644 docs/changelog/122134.yaml diff --git a/docs/changelog/122134.yaml b/docs/changelog/122134.yaml new file mode 100644 index 0000000000000..25ca556789525 --- /dev/null +++ b/docs/changelog/122134.yaml @@ -0,0 +1,5 @@ +pr: 122134 +summary: Adding integration for VoyageAI embeddings and rerank models +area: Machine Learning +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 5c129dd33ee61..96f21e2e6369c 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -186,7 +186,7 @@ static TransportVersion def(int id) { public static final TransportVersion ESQL_RETRY_ON_SHARD_LEVEL_FAILURE = def(9_006_0_00); public static final TransportVersion ESQL_PROFILE_ASYNC_NANOS = def(9_007_00_0); - public static final TransportVersion VOYAGE_AI_INTEGRATION_ADDED = def(9_005_0_00); + public static final TransportVersion VOYAGE_AI_INTEGRATION_ADDED = def(9_008_0_00); /* * STOP! READ THIS FIRST! No, really, * ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _ From 91de4c3a5724e60c42f56ce13bdd52eb0ef67798 Mon Sep 17 00:00:00 2001 From: fzowl Date: Tue, 11 Feb 2025 02:21:19 +0100 Subject: [PATCH 16/20] Spotless tests --- .../VoyageAIEmbeddingsServiceSettings.java | 6 +- .../voyageai/VoyageAIActionCreatorTests.java | 22 ++- .../VoyageAIEmbeddingsActionTests.java | 56 ++++--- .../VoyageAIEmbeddingsRequestEntityTests.java | 2 +- .../VoyageAIEmbeddingsRequestTests.java | 83 +++++---- .../voyageai/VoyageAIRerankRequestTests.java | 4 +- ...VoyageAIEmbeddingsResponseEntityTests.java | 85 ++-------- .../VoyageAIResponseHandlerTests.java | 1 - .../results/TextEmbeddingResultsTests.java | 5 +- .../voyageai/VoyageAIServiceTests.java | 157 ++++++++++-------- ...oyageAIEmbeddingsServiceSettingsTests.java | 12 +- .../VoyageAIRerankTaskSettingsTests.java | 4 +- 12 files changed, 222 insertions(+), 215 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettings.java index e0ab27e0d5c24..ef92cac0010b0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettings.java @@ -39,7 +39,11 @@ public class VoyageAIEmbeddingsServiceSettings extends FilteredXContentObject implements ServiceSettings { public static final String NAME = "voyageai_embeddings_service_settings"; public static final VoyageAIEmbeddingsServiceSettings EMPTY_SETTINGS = new VoyageAIEmbeddingsServiceSettings( - null, null, null, null, null + null, + null, + null, + null, + null ); public static final String EMBEDDING_TYPE = "embedding_type"; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreatorTests.java index faeb8d477f6f4..50c64468d732a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreatorTests.java @@ -43,9 +43,9 @@ import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; -import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; import static org.mockito.Mockito.mock; public class VoyageAIActionCreatorTests extends ESTestCase { @@ -125,12 +125,18 @@ public void testCreate_VoyageAIEmbeddingsModel() throws IOException { requestMap, is( Map.of( - "output_dtype","float", - "truncation", true, - "input_type", "query", - "output_dimension",1024, - "input", List.of("abc"), - "model", "model" + "output_dtype", + "float", + "truncation", + true, + "input_type", + "query", + "output_dimension", + 1024, + "input", + List.of("abc"), + "model", + "model" ) ) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIEmbeddingsActionTests.java index 50237bde23a37..d43299538bd12 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIEmbeddingsActionTests.java @@ -26,10 +26,10 @@ import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.HttpResult; -import org.elasticsearch.xpack.inference.external.http.sender.VoyageAIEmbeddingsRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.VoyageAIEmbeddingsRequestManager; import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIUtils; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingType; @@ -139,12 +139,18 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { requestMap, equalTo( Map.of( - "input", List.of("abc"), - "model", "model", - "input_type", "document", - "output_dtype", "float", - "truncation", true, - "output_dimension", 1024 + "input", + List.of("abc"), + "model", + "model", + "input_type", + "document", + "output_dtype", + "float", + "truncation", + true, + "output_dimension", + 1024 ) ) ); @@ -208,12 +214,18 @@ public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws I requestMap, is( Map.of( - "input", List.of("abc"), - "model", "model", - "input_type", "document", - "output_dtype", "int8", - "truncation", true, - "output_dimension", 1024 + "input", + List.of("abc"), + "model", + "model", + "input_type", + "document", + "output_dtype", + "int8", + "truncation", + true, + "output_dimension", + 1024 ) ) ); @@ -277,12 +289,18 @@ public void testExecute_ReturnsSuccessfulResponse_ForBinaryResponseType() throws requestMap, is( Map.of( - "input", List.of("abc"), - "model", "model", - "input_type", "document", - "output_dtype", "binary", - "truncation", true, - "output_dimension", 1024 + "input", + List.of("abc"), + "model", + "model", + "input_type", + "document", + "output_dtype", + "binary", + "truncation", + true, + "output_dimension", + 1024 ) ) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntityTests.java index 593cc803c4778..66b2287e9cb50 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntityTests.java @@ -17,8 +17,8 @@ import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings; -import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings; import org.hamcrest.MatcherAssert; import java.io.IOException; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestTests.java index 1c9abb45a11ee..868849542457c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestTests.java @@ -47,11 +47,7 @@ public void testCreateRequest_UrlDefined() throws IOException { ); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - MatcherAssert.assertThat(requestMap, is(Map.of( - "input", List.of("abc"), - "model", "model", - "output_dtype", "float" - ))); + MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "model", "output_dtype", "float"))); } public void testCreateRequest_AllOptionsDefined() throws IOException { @@ -59,7 +55,8 @@ public void testCreateRequest_AllOptionsDefined() throws IOException { List.of("abc"), VoyageAIEmbeddingsModelTests.createModel( "url", - "secret", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), + "secret", + new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), null, null, "model" @@ -80,12 +77,10 @@ public void testCreateRequest_AllOptionsDefined() throws IOException { ); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - MatcherAssert.assertThat(requestMap, is(Map.of( - "input", List.of("abc"), - "model", "model", - "input_type", "document", - "output_dtype", "float" - ))); + MatcherAssert.assertThat( + requestMap, + is(Map.of("input", List.of("abc"), "model", "model", "input_type", "document", "output_dtype", "float")) + ); } public void testCreateRequest_DimensionDefined() throws IOException { @@ -93,7 +88,8 @@ public void testCreateRequest_DimensionDefined() throws IOException { List.of("abc"), VoyageAIEmbeddingsModelTests.createModel( "url", - "secret", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), + "secret", + new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), null, 2048, "model" @@ -114,13 +110,23 @@ public void testCreateRequest_DimensionDefined() throws IOException { ); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - MatcherAssert.assertThat(requestMap, is(Map.of( - "input", List.of("abc"), - "model", "model", - "input_type", "document", - "output_dtype", "float", - "output_dimension", 2048 - ))); + MatcherAssert.assertThat( + requestMap, + is( + Map.of( + "input", + List.of("abc"), + "model", + "model", + "input_type", + "document", + "output_dtype", + "float", + "output_dimension", + 2048 + ) + ) + ); } public void testCreateRequest_EmbeddingTypeDefined() throws IOException { @@ -128,7 +134,8 @@ public void testCreateRequest_EmbeddingTypeDefined() throws IOException { List.of("abc"), VoyageAIEmbeddingsModelTests.createModel( "url", - "secret", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), + "secret", + new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), null, 2048, "model", @@ -150,13 +157,23 @@ public void testCreateRequest_EmbeddingTypeDefined() throws IOException { ); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - MatcherAssert.assertThat(requestMap, is(Map.of( - "input", List.of("abc"), - "model", "model", - "input_type", "document", - "output_dtype", "int8", - "output_dimension", 2048 - ))); + MatcherAssert.assertThat( + requestMap, + is( + Map.of( + "input", + List.of("abc"), + "model", + "model", + "input_type", + "document", + "output_dtype", + "int8", + "output_dimension", + 2048 + ) + ) + ); } public void testCreateRequest_InputTypeSearch() throws IOException { @@ -186,12 +203,10 @@ public void testCreateRequest_InputTypeSearch() throws IOException { ); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - MatcherAssert.assertThat(requestMap, is(Map.of( - "input", List.of("abc"), - "model", "model", - "input_type", "query", - "output_dtype", "float" - ))); + MatcherAssert.assertThat( + requestMap, + is(Map.of("input", List.of("abc"), "model", "model", "input_type", "query", "output_dtype", "float")) + ); } public static VoyageAIEmbeddingsRequest createRequest(List input, VoyageAIEmbeddingsModel model) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestTests.java index 0bffd9c7dc268..a11d259200b98 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestTests.java @@ -18,9 +18,9 @@ import java.util.List; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; -import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.sameInstance; public class VoyageAIRerankRequestTests extends ESTestCase { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntityTests.java index 993c0e09145f7..a81f07936c10a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntityTests.java @@ -49,12 +49,7 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException { VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest( List.of("abc", "def"), - createModel( - "url", - "api_key", - null, - "voyage-3-large" - ) + createModel("url", "api_key", null, "voyage-3-large") ); InferenceServiceResults parsedResults = VoyageAIEmbeddingsResponseEntity.fromResponse( @@ -63,7 +58,7 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException { ); assertThat( - ((InferenceTextEmbeddingFloatResults)parsedResults).embeddings(), + ((InferenceTextEmbeddingFloatResults) parsedResults).embeddings(), is(List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.014539449F, -0.015288644F }))) ); } @@ -100,12 +95,7 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest( List.of("abc", "def"), - createModel( - "url", - "api_key", - null, - "voyage-3-large" - ) + createModel("url", "api_key", null, "voyage-3-large") ); InferenceServiceResults parsedResults = VoyageAIEmbeddingsResponseEntity.fromResponse( @@ -114,7 +104,7 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException ); assertThat( - ((InferenceTextEmbeddingFloatResults)parsedResults).embeddings(), + ((InferenceTextEmbeddingFloatResults) parsedResults).embeddings(), is( List.of( new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.014539449F, -0.015288644F }), @@ -148,12 +138,7 @@ public void testFromResponse_FailsWhenDataFieldIsNotPresent() { VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest( List.of("abc", "def"), - createModel( - "url", - "api_key", - null, - "voyage-3-large" - ) + createModel("url", "api_key", null, "voyage-3-large") ); var thrownException = expectThrows( @@ -191,12 +176,7 @@ public void testFromResponse_FailsWhenDataFieldNotAnArray() { VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest( List.of("abc", "def"), - createModel( - "url", - "api_key", - null, - "voyage-3-large" - ) + createModel("url", "api_key", null, "voyage-3-large") ); var thrownException = expectThrows( @@ -237,12 +217,7 @@ public void testFromResponse_FailsWhenEmbeddingsDoesNotExist() { VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest( List.of("abc", "def"), - createModel( - "url", - "api_key", - null, - "voyage-3-large" - ) + createModel("url", "api_key", null, "voyage-3-large") ); var thrownException = expectThrows( @@ -279,12 +254,7 @@ public void testFromResponse_FailsWhenEmbeddingValueIsAString() { VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest( List.of("abc", "def"), - createModel( - "url", - "api_key", - null, - "voyage-3-large" - ) + createModel("url", "api_key", null, "voyage-3-large") ); var thrownException = expectThrows( @@ -324,12 +294,7 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsInt() throws IOExceptio VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest( List.of("abc", "def"), - createModel( - "url", - "api_key", - null, - "voyage-3-large" - ) + createModel("url", "api_key", null, "voyage-3-large") ); InferenceServiceResults parsedResults = VoyageAIEmbeddingsResponseEntity.fromResponse( @@ -338,7 +303,7 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsInt() throws IOExceptio ); assertThat( - ((InferenceTextEmbeddingFloatResults)parsedResults).embeddings(), + ((InferenceTextEmbeddingFloatResults) parsedResults).embeddings(), is(List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 1.0F }))) ); } @@ -366,12 +331,7 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsLong() throws IOExcepti VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest( List.of("abc", "def"), - createModel( - "url", - "api_key", - null, - "voyage-3-large" - ) + createModel("url", "api_key", null, "voyage-3-large") ); InferenceServiceResults parsedResults = VoyageAIEmbeddingsResponseEntity.fromResponse( @@ -380,7 +340,7 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsLong() throws IOExcepti ); assertThat( - ((InferenceTextEmbeddingFloatResults)parsedResults).embeddings(), + ((InferenceTextEmbeddingFloatResults) parsedResults).embeddings(), is(List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 4.0294965E10F }))) ); } @@ -408,12 +368,7 @@ public void testFromResponse_FailsWhenEmbeddingValueIsAnObject() { VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest( List.of("abc", "def"), - createModel( - "url", - "api_key", - null, - "voyage-3-large" - ) + createModel("url", "api_key", null, "voyage-3-large") ); var thrownException = expectThrows( @@ -473,12 +428,7 @@ public void testFieldsInDifferentOrderServer() throws IOException { VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest( List.of("abc", "def"), - createModel( - "url", - "api_key", - null, - "voyage-3-large" - ) + createModel("url", "api_key", null, "voyage-3-large") ); InferenceServiceResults parsedResults = VoyageAIEmbeddingsResponseEntity.fromResponse( @@ -486,13 +436,10 @@ public void testFieldsInDifferentOrderServer() throws IOException { new HttpResult(mock(HttpResponse.class), response.getBytes(StandardCharsets.UTF_8)) ); - assertThat( - parsedResults, - instanceOf(InferenceTextEmbeddingFloatResults.class) - ); + assertThat(parsedResults, instanceOf(InferenceTextEmbeddingFloatResults.class)); assertThat( - ((InferenceTextEmbeddingFloatResults)parsedResults).embeddings(), + ((InferenceTextEmbeddingFloatResults) parsedResults).embeddings(), is( List.of( new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { -0.9F, 0.5F, 0.3F }), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/voyageai/VoyageAIResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/voyageai/VoyageAIResponseHandlerTests.java index 0c45fa1b18429..d032116d9a894 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/voyageai/VoyageAIResponseHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/voyageai/VoyageAIResponseHandlerTests.java @@ -18,7 +18,6 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.http.retry.RetryException; -import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIResponseHandler; import org.elasticsearch.xpack.inference.external.request.Request; import org.hamcrest.MatcherAssert; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java index b6b0bbefdf0a5..09b73dc260693 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java @@ -147,9 +147,6 @@ public static Map buildExpectationByte(List embeddings) } public static Map buildExpectationBinary(List embeddings) { - return Map.of( - "text_embedding_bits", - embeddings.stream().map(InferenceByteEmbedding::new).toList() - ); + return Map.of("text_embedding_bits", embeddings.stream().map(InferenceByteEmbedding::new).toList()); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java index d995bff1e9caa..e16e315770e8c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java @@ -1138,13 +1138,20 @@ public void testInfer_Embedding_Get_Response_Ingest() throws IOException { var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); MatcherAssert.assertThat( requestMap, - is(Map.of( - "input", List.of("abc"), - "model", "voyage-clip-v2", - "input_type", "document", - "output_dtype", "float", - "output_dimension", 1024 - )) + is( + Map.of( + "input", + List.of("abc"), + "model", + "voyage-clip-v2", + "input_type", + "document", + "output_dtype", + "float", + "output_dimension", + 1024 + ) + ) ); } } @@ -1210,13 +1217,23 @@ public void testInfer_Embedding_Get_Response_Search() throws IOException { MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); - MatcherAssert.assertThat(requestMap, is(Map.of( - "input", List.of("abc"), - "model", "voyage-clip-v2", - "input_type", "query", - "output_dtype", "float", - "output_dimension", 1024 - ))); + MatcherAssert.assertThat( + requestMap, + is( + Map.of( + "input", + List.of("abc"), + "model", + "voyage-clip-v2", + "input_type", + "query", + "output_dtype", + "float", + "output_dimension", + 1024 + ) + ) + ); } } @@ -1265,12 +1282,10 @@ public void testInfer_Embedding_Get_Response_clustering() throws IOException { MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); - MatcherAssert.assertThat(requestMap, is(Map.of( - "input", List.of("abc"), - "model", "voyage-clip-v2", - "output_dtype", "float", - "output_dimension", 1024 - ))); + MatcherAssert.assertThat( + requestMap, + is(Map.of("input", List.of("abc"), "model", "voyage-clip-v2", "output_dtype", "float", "output_dimension", 1024)) + ); } } @@ -1326,12 +1341,10 @@ public void testInfer_Embedding_Get_Response_NullInputType() throws IOException MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); - MatcherAssert.assertThat(requestMap, is(Map.of( - "input", List.of("abc"), - "model", "voyage-clip-v2", - "output_dtype", "float", - "output_dimension", 1024 - ))); + MatcherAssert.assertThat( + requestMap, + is(Map.of("input", List.of("abc"), "model", "voyage-clip-v2", "output_dtype", "float", "output_dimension", 1024)) + ); } } @@ -1579,11 +1592,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocumentsNull_NoTopN() throws IO var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); MatcherAssert.assertThat( requestMap, - is(Map.of( - "query", "query", - "documents", List.of("candidate1", "candidate2", "candidate3"), - "model", "model" - )) + is(Map.of("query", "query", "documents", List.of("candidate1", "candidate2", "candidate3"), "model", "model")) ); } @@ -1741,13 +1750,23 @@ public void testInfer_Embedding_DoesNotSetInputType_WhenNotPresentInTaskSettings MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); - MatcherAssert.assertThat(requestMap, is(Map.of( - "input", List.of("abc"), - "model", "voyage-clip-v2", - "input_type", "document", - "output_dtype", "float", - "output_dimension", 1024 - ))); + MatcherAssert.assertThat( + requestMap, + is( + Map.of( + "input", + List.of("abc"), + "model", + "voyage-clip-v2", + "input_type", + "document", + "output_dtype", + "float", + "output_dimension", + 1024 + ) + ) + ); } } @@ -1853,12 +1872,10 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(VoyageAIEmbeddingsModel mo MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); - MatcherAssert.assertThat(requestMap, is(Map.of( - "input", List.of("foo", "bar"), - "model", "voyage-clip-v2", - "output_dtype", "float", - "output_dimension", 1024 - ))); + MatcherAssert.assertThat( + requestMap, + is(Map.of("input", List.of("foo", "bar"), "model", "voyage-clip-v2", "output_dtype", "float", "output_dimension", 1024)) + ); } } @@ -1869,35 +1886,33 @@ public void testDefaultSimilarity() { @SuppressWarnings("checkstyle:LineLength") public void testGetConfiguration() throws Exception { try (var service = createVoyageAIService()) { - String content = XContentHelper.stripWhitespace( - """ - { - "service": "voyageai", - "name": "Voyage AI", - "task_types": ["text_embedding", "rerank"], - "configurations": { - "api_key": { - "description": "API Key for the provider you're connecting to.", - "label": "API Key", - "required": true, - "sensitive": true, - "updatable": true, - "type": "str", - "supported_task_types": ["text_embedding", "rerank"] - }, - "rate_limit.requests_per_minute": { - "description": "Minimize the number of rate limit errors.", - "label": "Rate Limit", - "required": false, - "sensitive": false, - "updatable": false, - "type": "int", - "supported_task_types": ["text_embedding", "rerank"] - } + String content = XContentHelper.stripWhitespace(""" + { + "service": "voyageai", + "name": "Voyage AI", + "task_types": ["text_embedding", "rerank"], + "configurations": { + "api_key": { + "description": "API Key for the provider you're connecting to.", + "label": "API Key", + "required": true, + "sensitive": true, + "updatable": true, + "type": "str", + "supported_task_types": ["text_embedding", "rerank"] + }, + "rate_limit.requests_per_minute": { + "description": "Minimize the number of rate limit errors.", + "label": "Rate Limit", + "required": false, + "sensitive": false, + "updatable": false, + "type": "int", + "supported_task_types": ["text_embedding", "rerank"] } } - """ - ); + } + """); InferenceServiceConfiguration configuration = InferenceServiceConfiguration.fromXContentBytes( new BytesArray(content), XContentType.JSON diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettingsTests.java index 60eb46f016124..c4980e02ca42d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettingsTests.java @@ -156,9 +156,15 @@ public void testToXContent_WritesAllValues() throws IOException { XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); serviceSettings.toXContent(builder, null); String xContentResult = Strings.toString(builder); - assertThat(xContentResult, is(""" - {"url":"url","model_id":"model",""" + """ - "rate_limit":{"requests_per_minute":3},"similarity":"cosine","dimensions":5,"max_input_tokens":10,"embedding_type":"float"}""")); + assertThat( + xContentResult, + is( + """ + {"url":"url","model_id":"model",""" + + """ + "rate_limit":{"requests_per_minute":3},"similarity":"cosine","dimensions":5,"max_input_tokens":10,"embedding_type":"float"}""" + ) + ); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettingsTests.java index 6823bfb9d21de..02c8f9ae677ef 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettingsTests.java @@ -85,8 +85,8 @@ public void testFromMap_WithInvalidTopNDocsOnly_ThrowsValidationException() { var thrownException = expectThrows(ValidationException.class, () -> VoyageAIRerankTaskSettings.fromMap(new HashMap<>(taskMap))); assertThat( thrownException.getMessage(), - containsString("field [top_k] is not of the expected type. The value [invalid] cannot be converted to a [Integer];" - )); + containsString("field [top_k] is not of the expected type. The value [invalid] cannot be converted to a [Integer];") + ); } public void testUpdatedTaskSettings_WithEmptyMap_ReturnsSameSettings() { From cb9fd17861af1a92f52322100ee1306340c606c6 Mon Sep 17 00:00:00 2001 From: fzowl Date: Sun, 16 Feb 2025 23:46:41 +0100 Subject: [PATCH 17/20] Changes due to the comments --- .../org/elasticsearch/TransportVersions.java | 1 + .../InferenceNamedWriteablesProvider.java | 28 ++ .../voyageai/VoyageAIActionCreator.java | 10 +- .../http/sender/VoyageAIRequestManager.java | 28 +- .../voyageai/VoyageAIEmbeddingsRequest.java | 13 +- .../request/voyageai/VoyageAIRequest.java | 2 +- .../voyageai/VoyageAIRerankRequest.java | 13 +- .../VoyageAIEmbeddingsResponseEntity.java | 194 ++++++------ .../VoyageAIRerankResponseEntity.java | 179 +++++------ .../external/voyageai/VoyageAIAccount.java | 17 +- .../services/voyageai/VoyageAIModel.java | 38 ++- .../services/voyageai/VoyageAIService.java | 59 ++-- .../voyageai/VoyageAIServiceFields.java | 1 - .../voyageai/VoyageAIServiceSettings.java | 41 +-- .../embeddings/VoyageAIEmbeddingsModel.java | 31 +- .../VoyageAIEmbeddingsServiceSettings.java | 94 +++--- .../voyageai/rerank/VoyageAIRerankModel.java | 26 +- .../rerank/VoyageAIRerankTaskSettings.java | 6 - .../VoyageAIEmbeddingsActionTests.java | 16 +- .../voyageai/VoyageAIRequestTests.java | 4 +- ...VoyageAIEmbeddingsResponseEntityTests.java | 61 ++-- .../VoyageAIRerankResponseEntityTests.java | 4 +- .../VoyageAIServiceSettingsTests.java | 80 +---- .../voyageai/VoyageAIServiceTests.java | 279 +++++++++--------- .../VoyageAIEmbeddingsModelTests.java | 24 +- ...oyageAIEmbeddingsServiceSettingsTests.java | 175 +++++++++-- .../rerank/VoyageAIRerankModelTests.java | 18 +- .../VoyageAIRerankServiceSettingsTests.java | 13 +- 28 files changed, 803 insertions(+), 652 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 96f21e2e6369c..41088a6cdd58c 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -173,6 +173,7 @@ static TransportVersion def(int id) { public static final TransportVersion INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING = def(8_839_0_00); public static final TransportVersion ML_INFERENCE_IBM_WATSONX_RERANK_ADDED = def(8_840_0_00); public static final TransportVersion COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_X = def(8_840_0_01); + public static final TransportVersion VOYAGE_AI_INTEGRATION_ADDED_BACKPORT_8_X = def(8_840_0_03); public static final TransportVersion ELASTICSEARCH_9_0 = def(9_000_0_00); public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01); public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index e8dc763116707..1c6e60ba2426d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -88,6 +88,11 @@ import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings; +import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankServiceSettings; +import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankTaskSettings; import java.util.ArrayList; import java.util.List; @@ -140,6 +145,7 @@ public static List getNamedWriteables() { addEisNamedWriteables(namedWriteables); addAlibabaCloudSearchNamedWriteables(namedWriteables); addJinaAINamedWriteables(namedWriteables); + addVoyageAINamedWriteables(namedWriteables); addUnifiedNamedWriteables(namedWriteables); @@ -610,6 +616,28 @@ private static void addJinaAINamedWriteables(List ); } + private static void addVoyageAINamedWriteables(List namedWriteables) { + namedWriteables.add( + new NamedWriteableRegistry.Entry(ServiceSettings.class, VoyageAIServiceSettings.NAME, VoyageAIServiceSettings::new) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + VoyageAIEmbeddingsServiceSettings.NAME, + VoyageAIEmbeddingsServiceSettings::new + ) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry(TaskSettings.class, VoyageAIEmbeddingsTaskSettings.NAME, VoyageAIEmbeddingsTaskSettings::new) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry(ServiceSettings.class, VoyageAIRerankServiceSettings.NAME, VoyageAIRerankServiceSettings::new) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry(TaskSettings.class, VoyageAIRerankTaskSettings.NAME, VoyageAIRerankTaskSettings::new) + ); + } + private static void addEisNamedWriteables(List namedWriteables) { namedWriteables.add( new NamedWriteableRegistry.Entry( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreator.java index e5361e9e45d0c..6a4a9e5f93639 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreator.java @@ -37,10 +37,7 @@ public VoyageAIActionCreator(Sender sender, ServiceComponents serviceComponents) @Override public ExecutableAction create(VoyageAIEmbeddingsModel model, Map taskSettings, InputType inputType) { var overriddenModel = VoyageAIEmbeddingsModel.of(model, taskSettings, inputType); - var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage( - overriddenModel.getServiceSettings().getCommonSettings().uri(), - "VoyageAI embeddings" - ); + var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(overriddenModel.uri(), "VoyageAI embeddings"); var requestCreator = VoyageAIEmbeddingsRequestManager.of(overriddenModel, serviceComponents.threadPool()); return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage); } @@ -48,10 +45,7 @@ public ExecutableAction create(VoyageAIEmbeddingsModel model, Map taskSettings) { var overriddenModel = VoyageAIRerankModel.of(model, taskSettings); - var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage( - overriddenModel.getServiceSettings().getCommonSettings().uri(), - "VoyageAI rerank" - ); + var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(overriddenModel.uri(), "VoyageAI rerank"); var requestCreator = VoyageAIRerankRequestManager.of(overriddenModel, serviceComponents.threadPool()); return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/VoyageAIRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/VoyageAIRequestManager.java index 819bb74f237d6..99a0617ff510a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/VoyageAIRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/VoyageAIRequestManager.java @@ -10,9 +10,33 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIModel; +import java.util.Map; import java.util.Objects; abstract class VoyageAIRequestManager extends BaseRequestManager { + private static final String DEFAULT_MODEL_FAMILY = "default_model_family"; + private static final Map MODEL_TO_MODEL_FAMILY = Map.of( + "voyage-multimodal-3", + "embed_multimodal", + "voyage-3-large", + "embed_large", + "voyage-code-3", + "embed_large", + "voyage-3", + "embed_medium", + "voyage-3-lite", + "embed_small", + "voyage-finance-2", + "embed_large", + "voyage-law-2", + "embed_large", + "voyage-code-2", + "embed_large", + "rerank-2", + "rerank_large", + "rerank-2-lite", + "rerank_small" + ); protected VoyageAIRequestManager(ThreadPool threadPool, VoyageAIModel model) { super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings()); @@ -21,8 +45,10 @@ protected VoyageAIRequestManager(ThreadPool threadPool, VoyageAIModel model) { record RateLimitGrouping(int apiKeyHash) { public static RateLimitGrouping of(VoyageAIModel model) { Objects.requireNonNull(model); + String modelId = model.getServiceSettings().modelId(); + String modelFamily = MODEL_TO_MODEL_FAMILY.getOrDefault(modelId, DEFAULT_MODEL_FAMILY); - return new RateLimitGrouping(model.apiKey().hashCode()); + return new RateLimitGrouping(modelFamily.hashCode()); } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequest.java index ee01e25eafb33..a24f5dc8e14ea 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequest.java @@ -8,7 +8,6 @@ package org.elasticsearch.xpack.inference.external.request.voyageai; import org.apache.http.client.methods.HttpPost; -import org.apache.http.client.utils.URIBuilder; import org.apache.http.entity.ByteArrayEntity; import org.elasticsearch.common.Strings; import org.elasticsearch.xpack.inference.external.request.HttpRequest; @@ -19,7 +18,6 @@ import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings; import java.net.URI; -import java.net.URISyntaxException; import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Objects; @@ -36,7 +34,7 @@ public class VoyageAIEmbeddingsRequest extends VoyageAIRequest { public VoyageAIEmbeddingsRequest(List input, VoyageAIEmbeddingsModel embeddingsModel) { Objects.requireNonNull(embeddingsModel); - account = VoyageAIAccount.of(embeddingsModel, VoyageAIEmbeddingsRequest::buildDefaultUri); + account = VoyageAIAccount.of(embeddingsModel); this.input = Objects.requireNonNull(input); serviceSettings = embeddingsModel.getServiceSettings(); taskSettings = embeddingsModel.getTaskSettings(); @@ -54,7 +52,7 @@ public HttpRequest createHttpRequest() { ); httpPost.setEntity(byteEntity); - decorateWithAuthHeader(httpPost, account); + decorateWithHeaders(httpPost, account); return new HttpRequest(httpPost, getInferenceEntityId()); } @@ -86,11 +84,4 @@ public VoyageAIEmbeddingsTaskSettings getTaskSettings() { public VoyageAIEmbeddingsServiceSettings getServiceSettings() { return serviceSettings; } - - public static URI buildDefaultUri() throws URISyntaxException { - return new URIBuilder().setScheme("https") - .setHost(VoyageAIUtils.HOST) - .setPathSegments(VoyageAIUtils.VERSION_1, VoyageAIUtils.EMBEDDINGS_PATH) - .build(); - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRequest.java index de5dfe4db07e6..5455a2f0301f2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRequest.java @@ -17,7 +17,7 @@ public abstract class VoyageAIRequest implements Request { - public static void decorateWithAuthHeader(HttpPost request, VoyageAIAccount account) { + public static void decorateWithHeaders(HttpPost request, VoyageAIAccount account) { request.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); request.setHeader(createAuthBearerHeader(account.apiKey())); request.setHeader(VoyageAIUtils.createRequestSourceHeader()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequest.java index d1a29d69a482f..37d15fe1fe2c5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequest.java @@ -8,7 +8,6 @@ package org.elasticsearch.xpack.inference.external.request.voyageai; import org.apache.http.client.methods.HttpPost; -import org.apache.http.client.utils.URIBuilder; import org.apache.http.entity.ByteArrayEntity; import org.elasticsearch.common.Strings; import org.elasticsearch.xpack.inference.external.request.HttpRequest; @@ -18,7 +17,6 @@ import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankTaskSettings; import java.net.URI; -import java.net.URISyntaxException; import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Objects; @@ -35,7 +33,7 @@ public class VoyageAIRerankRequest extends VoyageAIRequest { public VoyageAIRerankRequest(String query, List input, VoyageAIRerankModel model) { Objects.requireNonNull(model); - this.account = VoyageAIAccount.of(model, VoyageAIRerankRequest::buildDefaultUri); + this.account = VoyageAIAccount.of(model); this.input = Objects.requireNonNull(input); this.query = Objects.requireNonNull(query); taskSettings = model.getTaskSettings(); @@ -52,7 +50,7 @@ public HttpRequest createHttpRequest() { ); httpPost.setEntity(byteEntity); - decorateWithAuthHeader(httpPost, account); + decorateWithHeaders(httpPost, account); return new HttpRequest(httpPost, getInferenceEntityId()); } @@ -76,11 +74,4 @@ public Request truncate() { public boolean[] getTruncationInfo() { return null; } - - public static URI buildDefaultUri() throws URISyntaxException { - return new URIBuilder().setScheme("https") - .setHost(VoyageAIUtils.HOST) - .setPathSegments(VoyageAIUtils.VERSION_1, VoyageAIUtils.RERANK_PATH) - .build(); - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntity.java index 218ef932c9bf8..472b277e7e0d8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntity.java @@ -10,7 +10,10 @@ package org.elasticsearch.xpack.inference.external.response.voyageai; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; @@ -22,23 +25,17 @@ import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIEmbeddingsRequest; -import org.elasticsearch.xpack.inference.external.response.XContentUtils; import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingType; import java.io.IOException; import java.util.Arrays; import java.util.List; -import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList; -import static org.elasticsearch.xpack.inference.external.response.XContentUtils.consumeUntilObjectEnd; -import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; -import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; import static org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingType.toLowerCase; public class VoyageAIEmbeddingsResponseEntity { - private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in VoyageAI embeddings response"; - private static final String VALID_EMBEDDING_TYPES_STRING = supportedEmbeddingTypes(); private static String supportedEmbeddingTypes() { @@ -50,6 +47,93 @@ private static String supportedEmbeddingTypes() { return String.join(", ", validTypes); } + record EmbeddingInt8Result(List entries, String model, String object, @Nullable Usage usage) { + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + EmbeddingInt8Result.class.getSimpleName(), + args -> new EmbeddingInt8Result((List) args[0], (String) args[1], (String) args[2], (Usage) args[3]) + ); + + static { + PARSER.declareObjectArray(constructorArg(), EmbeddingInt8ResultEntry.PARSER::apply, new ParseField("data")); + PARSER.declareString(constructorArg(), new ParseField("model")); + PARSER.declareString(constructorArg(), new ParseField("object")); + PARSER.declareObject(optionalConstructorArg(), Usage.PARSER::apply, new ParseField("usage")); + } + } + + record EmbeddingInt8ResultEntry(String object, Integer index, List embedding) { + + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + EmbeddingInt8ResultEntry.class.getSimpleName(), + args -> new EmbeddingInt8ResultEntry((String) args[0], (Integer) args[1], (List) args[2]) + ); + + static { + PARSER.declareString(constructorArg(), new ParseField("object")); + PARSER.declareInt(constructorArg(), new ParseField("index")); + PARSER.declareIntArray(constructorArg(), new ParseField("embedding")); + } + + private static void checkByteBounds(Integer value) { + if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) { + throw new IllegalArgumentException("Value [" + value + "] is out of range for a byte"); + } + } + + public InferenceByteEmbedding toInferenceByteEmbedding() { + embedding.forEach(EmbeddingInt8ResultEntry::checkByteBounds); + return InferenceByteEmbedding.of(embedding.stream().map(Integer::byteValue).toList()); + } + } + + record EmbeddingFloatResult(List entries, String model, String object, @Nullable Usage usage) { + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + EmbeddingFloatResult.class.getSimpleName(), + args -> new EmbeddingFloatResult((List) args[0], (String) args[1], (String) args[2], (Usage) args[3]) + ); + + static { + PARSER.declareObjectArray(constructorArg(), EmbeddingFloatResultEntry.PARSER::apply, new ParseField("data")); + PARSER.declareString(constructorArg(), new ParseField("model")); + PARSER.declareString(constructorArg(), new ParseField("object")); + PARSER.declareObject(optionalConstructorArg(), Usage.PARSER::apply, new ParseField("usage")); + } + } + + record EmbeddingFloatResultEntry(String object, Integer index, List embedding) { + + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + EmbeddingFloatResultEntry.class.getSimpleName(), + args -> new EmbeddingFloatResultEntry((String) args[0], (Integer) args[1], (List) args[2]) + ); + + static { + PARSER.declareString(constructorArg(), new ParseField("object")); + PARSER.declareInt(constructorArg(), new ParseField("index")); + PARSER.declareFloatArray(constructorArg(), new ParseField("embedding")); + } + + public InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding toInferenceFloatEmbedding() { + return InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding.of(embedding); + } + } + + record Usage(Integer totalTokens) { + + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + Usage.class.getSimpleName(), + args -> new Usage((Integer) args[0]) + ); + + static { + PARSER.declareInt(constructorArg(), new ParseField("total_tokens")); + } + } + /** * Parses the VoyageAI json response. * For a request like: @@ -100,33 +184,24 @@ public static InferenceServiceResults fromResponse(Request request, HttpResult r VoyageAIEmbeddingType embeddingType = ((VoyageAIEmbeddingsRequest) request).getServiceSettings().getEmbeddingType(); try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { - moveToFirstToken(jsonParser); - - XContentParser.Token token = jsonParser.currentToken(); - ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); - - positionParserAtTokenAfterField(jsonParser, "data", FAILED_TO_FIND_FIELD_TEMPLATE); - if (embeddingType == null || embeddingType == VoyageAIEmbeddingType.FLOAT) { - List embeddingList = parseList( - jsonParser, - VoyageAIEmbeddingsResponseEntity::parseEmbeddingObjectFloat - ); + var embeddingResult = EmbeddingFloatResult.PARSER.apply(jsonParser, null); + List embeddingList = embeddingResult.entries.stream() + .map(EmbeddingFloatResultEntry::toInferenceFloatEmbedding) + .toList(); return new InferenceTextEmbeddingFloatResults(embeddingList); } else if (embeddingType == VoyageAIEmbeddingType.INT8) { - List embeddingList = parseList( - jsonParser, - VoyageAIEmbeddingsResponseEntity::parseEmbeddingObjectByte - ); - + var embeddingResult = EmbeddingInt8Result.PARSER.apply(jsonParser, null); + List embeddingList = embeddingResult.entries.stream() + .map(EmbeddingInt8ResultEntry::toInferenceByteEmbedding) + .toList(); return new InferenceTextEmbeddingByteResults(embeddingList); } else if (embeddingType == VoyageAIEmbeddingType.BIT || embeddingType == VoyageAIEmbeddingType.BINARY) { - List embeddingList = parseList( - jsonParser, - VoyageAIEmbeddingsResponseEntity::parseEmbeddingObjectBit - ); - + var embeddingResult = EmbeddingInt8Result.PARSER.apply(jsonParser, null); + List embeddingList = embeddingResult.entries.stream() + .map(EmbeddingInt8ResultEntry::toInferenceByteEmbedding) + .toList(); return new InferenceTextEmbeddingBitResults(embeddingList); } else { throw new IllegalArgumentException( @@ -136,66 +211,5 @@ public static InferenceServiceResults fromResponse(Request request, HttpResult r } } - private static InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding parseEmbeddingObjectFloat(XContentParser parser) - throws IOException { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - - positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE); - - List embeddingValuesList = parseList(parser, XContentUtils::parseFloat); - // parse and discard the rest of the object - consumeUntilObjectEnd(parser); - - return InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding.of(embeddingValuesList); - } - - private static InferenceByteEmbedding parseEmbeddingObjectByte(XContentParser parser) throws IOException { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - - positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE); - - List embeddingValuesList = parseList(parser, VoyageAIEmbeddingsResponseEntity::parseEmbeddingInt8Entry); - // parse and discard the rest of the object - consumeUntilObjectEnd(parser); - - return InferenceByteEmbedding.of(embeddingValuesList); - } - - private static InferenceByteEmbedding parseEmbeddingObjectBit(XContentParser parser) throws IOException { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - - positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE); - - List embeddingValuesList = parseList(parser, VoyageAIEmbeddingsResponseEntity::parseEmbeddingBitEntry); - // parse and discard the rest of the object - consumeUntilObjectEnd(parser); - - return InferenceByteEmbedding.of(embeddingValuesList); - } - - private static Byte parseEmbeddingInt8Entry(XContentParser parser) throws IOException { - XContentParser.Token token = parser.currentToken(); - ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, token, parser); - var parsedByte = parser.shortValue(); - checkByteBounds(parsedByte); - - return (byte) parsedByte; - } - - private static Byte parseEmbeddingBitEntry(XContentParser parser) throws IOException { - XContentParser.Token token = parser.currentToken(); - ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, token, parser); - var parsedBit = parser.shortValue(); - checkByteBounds(parsedBit); - - return (byte) parsedBit; - } - - private static void checkByteBounds(short value) { - if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) { - throw new IllegalArgumentException("Value [" + value + "] is out of range for a byte"); - } - } - private VoyageAIEmbeddingsResponseEntity() {} } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIRerankResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIRerankResponseEntity.java index 2a1c1d868b248..e0ef73bd2c7ae 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIRerankResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIRerankResponseEntity.java @@ -12,7 +12,10 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; @@ -21,121 +24,103 @@ import org.elasticsearch.xpack.inference.external.http.HttpResult; import java.io.IOException; +import java.util.List; -import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList; -import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownField; -import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownToken; -import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; -import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; public class VoyageAIRerankResponseEntity { private static final Logger logger = LogManager.getLogger(VoyageAIRerankResponseEntity.class); - /** - * Parses the VoyageAI ranked response. - * For a request like: - * "model": "rerank-2", - * "query": "What is the capital of the United States?", - * "top_k": 2, - * "documents": ["Carson City is the capital city of the American state of Nevada.", - * "The Commonwealth of the Northern Mariana ... Its capital is Saipan.", - * "Washington, D.C. (also known as simply Washington or D.C., ... It is a federal district.", - * "Capital punishment (the death penalty) ... As of 2017, capital punishment is legal in 30 of the 50 states."] - *

- * The response will look like (without whitespace): - * { - * "object": "list", - * "data": [ - * { - * "relevance_score": 0.4375, - * "index": 0 - * }, - * { - * "relevance_score": 0.421875, - * "index": 1 - * } - * ], - * "model": "rerank-2", - * "usage": { - * "total_tokens": 26 - * } - * } - * @param response the http response from VoyageAI - * @return the parsed response - * @throws IOException if there is an error parsing the response - */ - public static InferenceServiceResults fromResponse(HttpResult response) throws IOException { - var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + record RerankResult(List entries, String model, String object, @Nullable Usage usage) { - try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { - moveToFirstToken(jsonParser); + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + RerankResult.class.getSimpleName(), + args -> new RerankResult((List) args[0], (String) args[1], (String) args[2], (Usage) args[3]) + ); + + static { + PARSER.declareObjectArray(constructorArg(), RerankResultEntry.PARSER::apply, new ParseField("data")); + PARSER.declareString(constructorArg(), new ParseField("model")); + PARSER.declareString(constructorArg(), new ParseField("object")); + PARSER.declareObject(optionalConstructorArg(), Usage.PARSER::apply, new ParseField("usage")); + } + } - XContentParser.Token token = jsonParser.currentToken(); - ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); + record RerankResultEntry(Float relevanceScore, Integer index, @Nullable String document) { - positionParserAtTokenAfterField(jsonParser, "data", FAILED_TO_FIND_FIELD_TEMPLATE); + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + RerankResultEntry.class.getSimpleName(), + args -> new RerankResultEntry((Float) args[0], (Integer) args[1], (String) args[2]) + ); - token = jsonParser.currentToken(); - if (token == XContentParser.Token.START_ARRAY) { - return new RankedDocsResults(parseList(jsonParser, VoyageAIRerankResponseEntity::parseRankedDocObject)); - } else { - throwUnknownToken(token, jsonParser); - } + static { + PARSER.declareFloat(constructorArg(), new ParseField("relevance_score")); + PARSER.declareInt(constructorArg(), new ParseField("index")); + PARSER.declareString(optionalConstructorArg(), new ParseField("document")); + } - // This should never be reached. The above code should either return successfully or hit the throwUnknownToken - // or throw a parsing exception - throw new IllegalStateException("Reached an invalid state while parsing the VoyageAI response"); + public RankedDocsResults.RankedDoc toRankedDoc() { + return new RankedDocsResults.RankedDoc(index, relevanceScore, document); } } - private static RankedDocsResults.RankedDoc parseRankedDocObject(XContentParser parser) throws IOException { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - int index = -1; - float relevanceScore = -1; - String documentText = null; - parser.nextToken(); - while (parser.currentToken() != XContentParser.Token.END_OBJECT) { - if (parser.currentToken() == XContentParser.Token.FIELD_NAME) { - switch (parser.currentName()) { - case "index": - parser.nextToken(); // move to VALUE_NUMBER - index = parser.intValue(); - parser.nextToken(); // move to next FIELD_NAME or END_OBJECT - break; - case "relevance_score": - parser.nextToken(); // move to VALUE_NUMBER - relevanceScore = parser.floatValue(); - parser.nextToken(); // move to next FIELD_NAME or END_OBJECT - break; - case "document": - parser.nextToken(); // move to START_OBJECT; document text is wrapped in an object - ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser); - documentText = parser.text(); - parser.nextToken();// move past END_OBJECT - // parser should now be at the next FIELD_NAME or END_OBJECT - break; - default: - throwUnknownField(parser.currentName(), parser); - } - } else { - parser.nextToken(); - } - } + record Usage(Integer totalTokens) { - if (index == -1) { - logger.warn("Failed to find required field [index] in VoyageAI rerank response"); - } - if (relevanceScore == -1) { - logger.warn("Failed to find required field [relevance_score] in VoyageAI rerank response"); + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + Usage.class.getSimpleName(), + args -> new Usage((Integer) args[0]) + ); + + static { + PARSER.declareInt(constructorArg(), new ParseField("total_tokens")); } - // documentText may or may not be present depending on the request parameter + } + + /** + * Parses the VoyageAI ranked response. + * For a request like: + * "model": "rerank-2", + * "query": "What is the capital of the United States?", + * "top_k": 2, + * "documents": ["Carson City is the capital city of the American state of Nevada.", + * "The Commonwealth of the Northern Mariana ... Its capital is Saipan.", + * "Washington, D.C. (also known as simply Washington or D.C., ... It is a federal district.", + * "Capital punishment (the death penalty) ... As of 2017, capital punishment is legal in 30 of the 50 states."] + *

+ * The response will look like (without whitespace): + * { + * "object": "list", + * "data": [ + * { + * "relevance_score": 0.4375, + * "index": 0 + * }, + * { + * "relevance_score": 0.421875, + * "index": 1 + * } + * ], + * "model": "rerank-2", + * "usage": { + * "total_tokens": 26 + * } + * } + * @param response the http response from VoyageAI + * @return the parsed response + * @throws IOException if there is an error parsing the response + */ + public static InferenceServiceResults fromResponse(HttpResult response) throws IOException { + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { + var rerankResult = RerankResult.PARSER.apply(jsonParser, null); - return new RankedDocsResults.RankedDoc(index, relevanceScore, documentText); + return new RankedDocsResults(rerankResult.entries.stream().map(RerankResultEntry::toRankedDoc).toList()); + } } private VoyageAIRerankResponseEntity() {} - - static String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in VoyageAI rerank response"; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/voyageai/VoyageAIAccount.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/voyageai/VoyageAIAccount.java index eda5b038b3c3e..bb7de2fc8ad0c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/voyageai/VoyageAIAccount.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/voyageai/VoyageAIAccount.java @@ -7,22 +7,25 @@ package org.elasticsearch.xpack.inference.external.voyageai; -import org.elasticsearch.common.CheckedSupplier; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIModel; import java.net.URI; import java.net.URISyntaxException; import java.util.Objects; -import static org.elasticsearch.xpack.inference.external.request.RequestUtils.buildUri; - public record VoyageAIAccount(URI uri, SecureString apiKey) { - public static VoyageAIAccount of(VoyageAIModel model, CheckedSupplier uriBuilder) { - var uri = buildUri(model.uri(), "VoyageAI", uriBuilder); - - return new VoyageAIAccount(uri, model.apiKey()); + public static VoyageAIAccount of(VoyageAIModel model) { + try { + var uri = model.buildUri(); + return new VoyageAIAccount(uri, model.apiKey()); + } catch (URISyntaxException e) { + // using bad request here so that potentially sensitive URL information does not get logged + throw new ElasticsearchStatusException("Failed to construct VoyageAI URL", RestStatus.BAD_REQUEST, e); + } } public VoyageAIAccount { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIModel.java index c8b953c1b8f97..e63a716b96617 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIModel.java @@ -21,37 +21,52 @@ import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets; import java.net.URI; +import java.net.URISyntaxException; import java.util.Map; import java.util.Objects; public abstract class VoyageAIModel extends Model { private final SecureString apiKey; private final VoyageAIRateLimitServiceSettings rateLimitServiceSettings; + protected final URI uri; public VoyageAIModel( ModelConfigurations configurations, ModelSecrets secrets, @Nullable ApiKeySecrets apiKeySecrets, VoyageAIRateLimitServiceSettings rateLimitServiceSettings + ) { + this(configurations, secrets, apiKeySecrets, rateLimitServiceSettings, null); + } + + public VoyageAIModel( + ModelConfigurations configurations, + ModelSecrets secrets, + @Nullable ApiKeySecrets apiKeySecrets, + VoyageAIRateLimitServiceSettings rateLimitServiceSettings, + String url ) { super(configurations, secrets); this.rateLimitServiceSettings = Objects.requireNonNull(rateLimitServiceSettings); - apiKey = ServiceUtils.apiKey(apiKeySecrets); + this.apiKey = ServiceUtils.apiKey(apiKeySecrets); + this.uri = url == null ? null : URI.create(url); } protected VoyageAIModel(VoyageAIModel model, TaskSettings taskSettings) { super(model, taskSettings); - rateLimitServiceSettings = model.rateLimitServiceSettings(); - apiKey = model.apiKey(); + this.rateLimitServiceSettings = model.rateLimitServiceSettings(); + this.apiKey = model.apiKey(); + this.uri = model.uri; } protected VoyageAIModel(VoyageAIModel model, ServiceSettings serviceSettings) { super(model, serviceSettings); - rateLimitServiceSettings = model.rateLimitServiceSettings(); - apiKey = model.apiKey(); + this.rateLimitServiceSettings = model.rateLimitServiceSettings(); + this.apiKey = model.apiKey(); + this.uri = model.uri; } public SecureString apiKey() { @@ -64,5 +79,16 @@ public VoyageAIRateLimitServiceSettings rateLimitServiceSettings() { public abstract ExecutableAction accept(VoyageAIActionVisitor creator, Map taskSettings, InputType inputType); - public abstract URI uri(); + public URI uri() { + return uri; + } + + public URI buildUri() throws URISyntaxException { + if (uri == null) { + return buildRequestUri(); + } + return uri; + } + + protected abstract URI buildRequestUri() throws URISyntaxException; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java index 63d1b8a4acc95..f92779de9b7f5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java @@ -25,6 +25,7 @@ import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; @@ -49,6 +50,7 @@ import java.util.List; import java.util.Map; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; @@ -63,10 +65,29 @@ public class VoyageAIService extends SenderService { private static final String SERVICE_NAME = "Voyage AI"; private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK); - private static final int DEFAULT_VOYAGE_2_BATCH_SIZE = 72; - private static final int DEFAULT_VOYAGE_3_LITE_BATCH_SIZE = 30; - private static final int DEFAULT_VOYAGE_3_BATCH_SIZE = 10; - private static final int DEFAULT_BATCH_SIZE = 7; + private static final Integer DEFAULT_BATCH_SIZE = 7; + private static final Map MODEL_BATCH_SIZES = Map.of( + "voyage-multimodal-3", + 7, + "voyage-3-large", + 7, + "voyage-code-3", + 7, + "voyage-3", + 10, + "voyage-3-lite", + 30, + "voyage-finance-2", + 7, + "voyage-law-2", + 7, + "voyage-code-2", + 7, + "voyage-2", + 72, + "voyage-02", + 72 + ); public VoyageAIService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { super(factory, serviceComponents); @@ -281,17 +302,7 @@ protected void doChunkedInfer( } private static int getBatchSize(VoyageAIModel model) { - int maxBatchSize = DEFAULT_BATCH_SIZE; - - if ("voyage-2".equals(model.getServiceSettings().modelId()) || "voyage-02".equals(model.getServiceSettings().modelId())) { - maxBatchSize = DEFAULT_VOYAGE_2_BATCH_SIZE; - } else if ("voyage-3-lite".equals(model.getServiceSettings().modelId())) { - maxBatchSize = DEFAULT_VOYAGE_3_LITE_BATCH_SIZE; - } else if ("voyage-3".equals(model.getServiceSettings().modelId())) { - maxBatchSize = DEFAULT_VOYAGE_3_BATCH_SIZE; - } - - return maxBatchSize; + return MODEL_BATCH_SIZES.getOrDefault(model.getServiceSettings().modelId(), DEFAULT_BATCH_SIZE); } /** @@ -313,17 +324,18 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { var similarityFromModel = serviceSettings.similarity(); var similarityToUse = similarityFromModel == null ? defaultSimilarity() : similarityFromModel; var maxInputTokens = serviceSettings.maxInputTokens(); + var dimensionSetByUser = serviceSettings.dimensionsSetByUser(); var updatedServiceSettings = new VoyageAIEmbeddingsServiceSettings( new VoyageAIServiceSettings( - serviceSettings.getCommonSettings().uri(), serviceSettings.getCommonSettings().modelId(), serviceSettings.getCommonSettings().rateLimitSettings() ), serviceSettings.getEmbeddingType(), similarityToUse, embeddingSize, - maxInputTokens + maxInputTokens, + dimensionSetByUser ); return new VoyageAIEmbeddingsModel(embeddingsModel, updatedServiceSettings); @@ -358,6 +370,19 @@ public static InferenceServiceConfiguration get() { () -> { var configurationMap = new HashMap(); + configurationMap.put( + MODEL_ID, + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription( + "The name of the model to use for the inference task." + ) + .setLabel("Model ID") + .setRequired(true) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ); + configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(supportedTaskTypes)); configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(supportedTaskTypes)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceFields.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceFields.java index 6d6e84f2e0c11..b36f212b61e5d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceFields.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceFields.java @@ -9,5 +9,4 @@ public class VoyageAIServiceFields { public static final String TRUNCATION = "truncation"; - public static final String OUTPUT_DIMENSION = "output_dimension"; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceSettings.java index 76450913453ff..75497d1a4b4f0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceSettings.java @@ -23,14 +23,9 @@ import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import java.io.IOException; -import java.net.URI; import java.util.Map; import java.util.Objects; -import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertToUri; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.createOptionalUri; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; public class VoyageAIServiceSettings extends FilteredXContentObject implements ServiceSettings, VoyageAIRateLimitServiceSettings { @@ -44,8 +39,6 @@ public class VoyageAIServiceSettings extends FilteredXContentObject implements S public static VoyageAIServiceSettings fromMap(Map map, ConfigurationParseContext context) { ValidationException validationException = new ValidationException(); - String url = extractOptionalString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException); - URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException); RateLimitSettings rateLimitSettings = RateLimitSettings.of( map, DEFAULT_RATE_LIMIT_SETTINGS, @@ -60,26 +53,19 @@ public static VoyageAIServiceSettings fromMap(Map map, Configura throw validationException; } - return new VoyageAIServiceSettings(uri, modelId, rateLimitSettings); + return new VoyageAIServiceSettings(modelId, rateLimitSettings); } - private final URI uri; private final String modelId; private final RateLimitSettings rateLimitSettings; - public VoyageAIServiceSettings(@Nullable URI uri, String modelId, @Nullable RateLimitSettings rateLimitSettings) { - this.uri = uri; + public VoyageAIServiceSettings(String modelId, @Nullable RateLimitSettings rateLimitSettings) { this.modelId = Objects.requireNonNull(modelId); this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); } - public VoyageAIServiceSettings(@Nullable String url, String modelId, @Nullable RateLimitSettings rateLimitSettings) { - this(createOptionalUri(url), modelId, rateLimitSettings); - } - public VoyageAIServiceSettings(StreamInput in) throws IOException { - uri = createOptionalUri(in.readOptionalString()); - modelId = in.readOptionalString(); + modelId = in.readString(); rateLimitSettings = new RateLimitSettings(in); } @@ -88,10 +74,6 @@ public RateLimitSettings rateLimitSettings() { return rateLimitSettings; } - public URI uri() { - return uri; - } - @Override public String modelId() { return modelId; @@ -118,12 +100,7 @@ public XContentBuilder toXContentFragment(XContentBuilder builder, Params params @Override public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { - if (uri != null) { - builder.field(URL, uri.toString()); - } - if (modelId != null) { - builder.field(MODEL_ID, modelId); - } + builder.field(MODEL_ID, modelId); rateLimitSettings.toXContent(builder, params); return builder; @@ -136,9 +113,7 @@ public TransportVersion getMinimalSupportedVersion() { @Override public void writeTo(StreamOutput out) throws IOException { - var uriToWrite = uri != null ? uri.toString() : null; - out.writeOptionalString(uriToWrite); - out.writeOptionalString(modelId); + out.writeString(modelId); rateLimitSettings.writeTo(out); } @@ -147,13 +122,11 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; VoyageAIServiceSettings that = (VoyageAIServiceSettings) o; - return Objects.equals(uri, that.uri) - && Objects.equals(modelId, that.modelId) - && Objects.equals(rateLimitSettings, that.rateLimitSettings); + return Objects.equals(modelId, that.modelId) && Objects.equals(rateLimitSettings, that.rateLimitSettings); } @Override public int hashCode() { - return Objects.hash(uri, modelId, rateLimitSettings); + return Objects.hash(modelId, rateLimitSettings); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsModel.java index b20142ca32970..41194f6862a44 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsModel.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.services.voyageai.embeddings; +import org.apache.http.client.utils.URIBuilder; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InputType; @@ -15,13 +16,17 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.voyageai.VoyageAIActionVisitor; +import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIUtils; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIModel; import java.net.URI; +import java.net.URISyntaxException; import java.util.Map; +import static org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIUtils.HOST; + public class VoyageAIEmbeddingsModel extends VoyageAIModel { public static VoyageAIEmbeddingsModel of(VoyageAIEmbeddingsModel model, Map taskSettings, InputType inputType) { var requestTaskSettings = VoyageAIEmbeddingsTaskSettings.fromMap(taskSettings); @@ -67,6 +72,24 @@ public VoyageAIEmbeddingsModel( ); } + VoyageAIEmbeddingsModel( + String modelId, + String service, + String url, + VoyageAIEmbeddingsServiceSettings serviceSettings, + VoyageAIEmbeddingsTaskSettings taskSettings, + ChunkingSettings chunkingSettings, + @Nullable DefaultSecretSettings secretSettings + ) { + super( + new ModelConfigurations(modelId, TaskType.TEXT_EMBEDDING, service, serviceSettings, taskSettings, chunkingSettings), + new ModelSecrets(secretSettings), + secretSettings, + serviceSettings.getCommonSettings(), + url + ); + } + private VoyageAIEmbeddingsModel(VoyageAIEmbeddingsModel model, VoyageAIEmbeddingsTaskSettings taskSettings) { super(model, taskSettings); } @@ -95,8 +118,10 @@ public ExecutableAction accept(VoyageAIActionVisitor visitor, Map map, ConfigurationParseContext context) { + return switch (context) { + case REQUEST -> fromRequestMap(map, context); + case PERSISTENT -> fromPersistentMap(map, context); + }; + } + + private static VoyageAIEmbeddingsServiceSettings fromRequestMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + var commonServiceSettings = VoyageAIServiceSettings.fromMap(map, context); + + VoyageAIEmbeddingType embeddingTypes = parseEmbeddingType(map, context, validationException); + + SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); + Integer dims = removeAsType(map, DIMENSIONS, Integer.class); + Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new VoyageAIEmbeddingsServiceSettings(commonServiceSettings, embeddingTypes, similarity, dims, maxInputTokens, dims != null); + } + + private static VoyageAIEmbeddingsServiceSettings fromPersistentMap(Map map, ConfigurationParseContext context) { ValidationException validationException = new ValidationException(); var commonServiceSettings = VoyageAIServiceSettings.fromMap(map, context); @@ -58,11 +82,23 @@ public static VoyageAIEmbeddingsServiceSettings fromMap(Map map, Integer dims = removeAsType(map, DIMENSIONS, Integer.class); Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); + Boolean dimensionsSetByUser = removeAsType(map, DIMENSIONS_SET_BY_USER, Boolean.class); + if (dimensionsSetByUser == null) { + dimensionsSetByUser = Boolean.FALSE; + } + if (validationException.validationErrors().isEmpty() == false) { throw validationException; } - return new VoyageAIEmbeddingsServiceSettings(commonServiceSettings, embeddingTypes, similarity, dims, maxInputTokens); + return new VoyageAIEmbeddingsServiceSettings( + commonServiceSettings, + embeddingTypes, + similarity, + dims, + maxInputTokens, + dimensionsSetByUser + ); } static VoyageAIEmbeddingType parseEmbeddingType( @@ -71,7 +107,7 @@ static VoyageAIEmbeddingType parseEmbeddingType( ValidationException validationException ) { return switch (context) { - case REQUEST -> Objects.requireNonNullElse( + case REQUEST, PERSISTENT -> Objects.requireNonNullElse( extractOptionalEnum( map, EMBEDDING_TYPE, @@ -82,59 +118,31 @@ static VoyageAIEmbeddingType parseEmbeddingType( ), VoyageAIEmbeddingType.FLOAT ); - case PERSISTENT -> { - var embeddingType = ServiceUtils.extractOptionalString( - map, - EMBEDDING_TYPE, - ModelConfigurations.SERVICE_SETTINGS, - validationException - ); - yield fromVoyageAIOrDenseVectorEnumValues(embeddingType, validationException); - } }; } - static VoyageAIEmbeddingType fromVoyageAIOrDenseVectorEnumValues(String enumString, ValidationException validationException) { - if (enumString == null) { - return VoyageAIEmbeddingType.FLOAT; - } - - try { - return VoyageAIEmbeddingType.fromString(enumString); - } catch (IllegalArgumentException ae) { - try { - return VoyageAIEmbeddingType.fromElementType(DenseVectorFieldMapper.ElementType.fromString(enumString)); - } catch (IllegalArgumentException iae) { - var validValuesAsStrings = VoyageAIEmbeddingType.SUPPORTED_ELEMENT_TYPES.stream() - .map(value -> value.toString().toLowerCase(Locale.ROOT)) - .toArray(String[]::new); - validationException.addValidationError( - ServiceUtils.invalidValue(EMBEDDING_TYPE, ModelConfigurations.SERVICE_SETTINGS, enumString, validValuesAsStrings) - ); - return null; - } - } - } - private final VoyageAIServiceSettings commonSettings; private final VoyageAIEmbeddingType embeddingType; private final SimilarityMeasure similarity; private final Integer dimensions; private final Integer maxInputTokens; + private final Boolean dimensionsSetByUser; public VoyageAIEmbeddingsServiceSettings( VoyageAIServiceSettings commonSettings, @Nullable VoyageAIEmbeddingType embeddingType, @Nullable SimilarityMeasure similarity, @Nullable Integer dimensions, - @Nullable Integer maxInputTokens + @Nullable Integer maxInputTokens, + Boolean dimensionsSetByUser ) { this.commonSettings = commonSettings; this.similarity = similarity; this.dimensions = dimensions; this.maxInputTokens = maxInputTokens; this.embeddingType = embeddingType; + this.dimensionsSetByUser = Objects.requireNonNull(dimensionsSetByUser); } public VoyageAIEmbeddingsServiceSettings(StreamInput in) throws IOException { @@ -143,6 +151,7 @@ public VoyageAIEmbeddingsServiceSettings(StreamInput in) throws IOException { this.dimensions = in.readOptionalVInt(); this.maxInputTokens = in.readOptionalVInt(); this.embeddingType = Objects.requireNonNullElse(in.readOptionalEnum(VoyageAIEmbeddingType.class), VoyageAIEmbeddingType.FLOAT); + this.dimensionsSetByUser = in.readBoolean(); } public VoyageAIServiceSettings getCommonSettings() { @@ -177,6 +186,11 @@ public DenseVectorFieldMapper.ElementType elementType() { return embeddingType == null ? DenseVectorFieldMapper.ElementType.FLOAT : embeddingType.toElementType(); } + @Override + public Boolean dimensionsSetByUser() { + return this.dimensionsSetByUser; + } + @Override public String getWriteableName() { return NAME; @@ -222,6 +236,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalVInt(dimensions); out.writeOptionalVInt(maxInputTokens); out.writeOptionalEnum(embeddingType); + out.writeBoolean(dimensionsSetByUser); } @Override @@ -233,11 +248,12 @@ public boolean equals(Object o) { && Objects.equals(similarity, that.similarity) && Objects.equals(dimensions, that.dimensions) && Objects.equals(maxInputTokens, that.maxInputTokens) - && Objects.equals(embeddingType, that.embeddingType); + && Objects.equals(embeddingType, that.embeddingType) + && Objects.equals(dimensionsSetByUser, that.dimensionsSetByUser); } @Override public int hashCode() { - return Objects.hash(commonSettings, similarity, dimensions, maxInputTokens, embeddingType); + return Objects.hash(commonSettings, similarity, dimensions, maxInputTokens, embeddingType, dimensionsSetByUser); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModel.java index fd6c0ee6c5002..57c478962b5f2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModel.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.services.voyageai.rerank; +import org.apache.http.client.utils.URIBuilder; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.ModelConfigurations; @@ -14,13 +15,17 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.voyageai.VoyageAIActionVisitor; +import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIUtils; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIModel; import java.net.URI; +import java.net.URISyntaxException; import java.util.Map; +import static org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIUtils.HOST; + public class VoyageAIRerankModel extends VoyageAIModel { public static VoyageAIRerankModel of(VoyageAIRerankModel model, Map taskSettings) { var requestTaskSettings = VoyageAIRerankTaskSettings.fromMap(taskSettings); @@ -51,12 +56,24 @@ public VoyageAIRerankModel( VoyageAIRerankServiceSettings serviceSettings, VoyageAIRerankTaskSettings taskSettings, @Nullable DefaultSecretSettings secretSettings + ) { + this(modelId, service, null, serviceSettings, taskSettings, secretSettings); + } + + VoyageAIRerankModel( + String modelId, + String service, + String url, + VoyageAIRerankServiceSettings serviceSettings, + VoyageAIRerankTaskSettings taskSettings, + @Nullable DefaultSecretSettings secretSettings ) { super( new ModelConfigurations(modelId, TaskType.RERANK, service, serviceSettings, taskSettings), new ModelSecrets(secretSettings), secretSettings, - serviceSettings.getCommonSettings() + serviceSettings.getCommonSettings(), + url ); } @@ -96,7 +113,10 @@ public ExecutableAction accept(VoyageAIActionVisitor visitor, Map createAction("^^", "secret", VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender) - ); - MatcherAssert.assertThat(thrownException.getMessage(), containsString("unable to parse url [^^]")); - } - } - public void testExecute_ThrowsElasticsearchException() { var sender = mock(Sender.class); doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any(), any(), any()); @@ -416,10 +405,7 @@ private ExecutableAction createAction( Sender sender ) { var model = VoyageAIEmbeddingsModelTests.createModel(url, apiKey, taskSettings, 1024, 1024, modelName, embeddingType); - var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage( - model.getServiceSettings().getCommonSettings().uri(), - "VoyageAI embeddings" - ); + var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(model.uri(), "VoyageAI embeddings"); var requestCreator = VoyageAIEmbeddingsRequestManager.of(model, threadPool); return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRequestTests.java index 72216d6382968..5244ef83a7c7b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRequestTests.java @@ -20,10 +20,10 @@ public class VoyageAIRequestTests extends ESTestCase { - public void testDecorateWithAuthHeader() { + public void testDecorateWithHeaders() { var request = new HttpPost("http://www.abc.com"); - VoyageAIRequest.decorateWithAuthHeader( + VoyageAIRequest.decorateWithHeaders( request, new VoyageAIAccount(URI.create("http://www.abc.com"), new SecureString(new char[] { 'a', 'b', 'c' })) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntityTests.java index a81f07936c10a..93d92a34b4284 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntityTests.java @@ -8,9 +8,9 @@ package org.elasticsearch.xpack.inference.external.response.voyageai; import org.apache.http.HttpResponse; -import org.elasticsearch.common.ParsingException; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentParseException; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIEmbeddingsRequest; @@ -39,9 +39,8 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException { ] } ], - "model": "voyage-embeddings-v3", + "model": "voyage-3-large", "usage": { - "prompt_tokens": 8, "total_tokens": 8 } } @@ -85,9 +84,8 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException ] } ], - "model": "voyage-embeddings-v3", + "model": "voyage-3-large", "usage": { - "prompt_tokens": 8, "total_tokens": 8 } } @@ -128,9 +126,8 @@ public void testFromResponse_FailsWhenDataFieldIsNotPresent() { ] } ], - "model": "voyage-embeddings-v3", + "model": "voyage-3-large", "usage": { - "prompt_tokens": 8, "total_tokens": 8 } } @@ -142,14 +139,14 @@ public void testFromResponse_FailsWhenDataFieldIsNotPresent() { ); var thrownException = expectThrows( - IllegalStateException.class, + XContentParseException.class, () -> VoyageAIEmbeddingsResponseEntity.fromResponse( request, new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ) ); - assertThat(thrownException.getMessage(), is("Failed to find required field [data] in VoyageAI embeddings response")); + assertThat(thrownException.getMessage(), is("[3:3] [EmbeddingFloatResult] unknown field [not_data]")); } public void testFromResponse_FailsWhenDataFieldNotAnArray() { @@ -166,9 +163,8 @@ public void testFromResponse_FailsWhenDataFieldNotAnArray() { ] } }, - "model": "voyage-embeddings-v3", + "model": "voyage-3-large", "usage": { - "prompt_tokens": 8, "total_tokens": 8 } } @@ -180,17 +176,14 @@ public void testFromResponse_FailsWhenDataFieldNotAnArray() { ); var thrownException = expectThrows( - ParsingException.class, + XContentParseException.class, () -> VoyageAIEmbeddingsResponseEntity.fromResponse( request, new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ) ); - assertThat( - thrownException.getMessage(), - is("Failed to parse object: expecting token of type [START_ARRAY] but found [START_OBJECT]") - ); + assertThat(thrownException.getMessage(), is("[4:15] [EmbeddingFloatResult] failed to parse field [data]")); } public void testFromResponse_FailsWhenEmbeddingsDoesNotExist() { @@ -207,9 +200,8 @@ public void testFromResponse_FailsWhenEmbeddingsDoesNotExist() { ] } ], - "model": "voyage-embeddings-v3", + "model": "voyage-3-large", "usage": { - "prompt_tokens": 8, "total_tokens": 8 } } @@ -221,14 +213,14 @@ public void testFromResponse_FailsWhenEmbeddingsDoesNotExist() { ); var thrownException = expectThrows( - IllegalStateException.class, + XContentParseException.class, () -> VoyageAIEmbeddingsResponseEntity.fromResponse( request, new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ) ); - assertThat(thrownException.getMessage(), is("Failed to find required field [embedding] in VoyageAI embeddings response")); + assertThat(thrownException.getMessage(), is("[7:27] [EmbeddingFloatResult] failed to parse field [data]")); } public void testFromResponse_FailsWhenEmbeddingValueIsAString() { @@ -244,9 +236,8 @@ public void testFromResponse_FailsWhenEmbeddingValueIsAString() { ] } ], - "model": "voyage-embeddings-v3", + "model": "voyage-3-large", "usage": { - "prompt_tokens": 8, "total_tokens": 8 } } @@ -258,17 +249,14 @@ public void testFromResponse_FailsWhenEmbeddingValueIsAString() { ); var thrownException = expectThrows( - ParsingException.class, + XContentParseException.class, () -> VoyageAIEmbeddingsResponseEntity.fromResponse( request, new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ) ); - assertThat( - thrownException.getMessage(), - is("Failed to parse object: expecting token of type [VALUE_NUMBER] but found [VALUE_STRING]") - ); + assertThat(thrownException.getMessage(), is("[8:15] [EmbeddingFloatResult] failed to parse field [data]")); } public void testFromResponse_SucceedsWhenEmbeddingValueIsInt() throws IOException { @@ -284,9 +272,8 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsInt() throws IOExceptio ] } ], - "model": "voyage-embeddings-v3", + "model": "voyage-3-large", "usage": { - "prompt_tokens": 8, "total_tokens": 8 } } @@ -321,9 +308,8 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsLong() throws IOExcepti ] } ], - "model": "voyage-embeddings-v3", + "model": "voyage-3-large", "usage": { - "prompt_tokens": 8, "total_tokens": 8 } } @@ -358,9 +344,8 @@ public void testFromResponse_FailsWhenEmbeddingValueIsAnObject() { ] } ], - "model": "voyage-embeddings-v3", + "model": "voyage-3-large", "usage": { - "prompt_tokens": 8, "total_tokens": 8 } } @@ -372,17 +357,14 @@ public void testFromResponse_FailsWhenEmbeddingValueIsAnObject() { ); var thrownException = expectThrows( - ParsingException.class, + XContentParseException.class, () -> VoyageAIEmbeddingsResponseEntity.fromResponse( request, new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ) ); - assertThat( - thrownException.getMessage(), - is("Failed to parse object: expecting token of type [VALUE_NUMBER] but found [START_OBJECT]") - ); + assertThat(thrownException.getMessage(), is("[8:15] [EmbeddingFloatResult] failed to parse field [data]")); } public void testFieldsInDifferentOrderServer() throws IOException { @@ -390,7 +372,6 @@ public void testFieldsInDifferentOrderServer() throws IOException { String response = """ { "object": "list", - "id": "6667830b-716b-4796-9a61-33b67b5cc81d", "model": "voyage-3-large", "data": [ { @@ -420,8 +401,6 @@ public void testFieldsInDifferentOrderServer() throws IOException { } ], "usage": { - "prompt_tokens": 0, - "completion_tokens": 0, "total_tokens": 0 } }"""; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIRerankResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIRerankResponseEntityTests.java index 1e7a89e5a406e..5c7aa7a80ad8f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIRerankResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIRerankResponseEntityTests.java @@ -28,6 +28,7 @@ public class VoyageAIRerankResponseEntityTests extends ESTestCase { public void testResponseLiteral() throws IOException { String responseLiteral = """ { + "object": "list", "model": "model", "data": [ { @@ -67,7 +68,7 @@ public void testGeneratedResponse() throws IOException { responseBuilder.append("{"); responseBuilder.append("\"model\": \"model\","); - responseBuilder.append("\"index\":\"").append(randomAlphaOfLength(36)).append("\","); + responseBuilder.append("\"object\": \"list\","); responseBuilder.append("\"data\": ["); List indices = linear(numDocs); List scores = linearFloats(numDocs); @@ -108,6 +109,7 @@ private ArrayList responseLiteralDocs() { public void testResponseLiteralWithDocuments() throws IOException { String responseLiteralWithDocuments = """ { + "object": "list", "model": "model", "data": [ { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceSettingsTests.java index 2ee0c45e2d93b..09d890bd21f67 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceSettingsTests.java @@ -8,16 +8,12 @@ package org.elasticsearch.xpack.inference.services.voyageai; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.core.Nullable; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; -import org.elasticsearch.xpack.inference.services.ServiceFields; -import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; import org.hamcrest.MatcherAssert; @@ -26,48 +22,38 @@ import java.util.HashMap; import java.util.Map; -import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; public class VoyageAIServiceSettingsTests extends AbstractWireSerializingTestCase { public static VoyageAIServiceSettings createRandomWithNonNullUrl() { - return createRandom(randomAlphaOfLength(15)); + return createRandom(); } /** * The created settings can have a url set to null. */ public static VoyageAIServiceSettings createRandom() { - var url = randomBoolean() ? randomAlphaOfLength(15) : null; - return createRandom(url); - } - - private static VoyageAIServiceSettings createRandom(String url) { var model = randomAlphaOfLength(15); - return new VoyageAIServiceSettings(ServiceUtils.createOptionalUri(url), model, RateLimitSettingsTests.createRandom()); + return new VoyageAIServiceSettings(model, RateLimitSettingsTests.createRandom()); } public void testFromMap() { - var url = "https://www.abc.com"; var model = "model"; var serviceSettings = VoyageAIServiceSettings.fromMap( - new HashMap<>(Map.of(ServiceFields.URL, url, VoyageAIServiceSettings.MODEL_ID, model)), + new HashMap<>(Map.of(VoyageAIServiceSettings.MODEL_ID, model)), ConfigurationParseContext.REQUEST ); - MatcherAssert.assertThat(serviceSettings, is(new VoyageAIServiceSettings(ServiceUtils.createUri(url), model, null))); + MatcherAssert.assertThat(serviceSettings, is(new VoyageAIServiceSettings(model, null))); } public void testFromMap_WithRateLimit() { - var url = "https://www.abc.com"; var model = "model"; var serviceSettings = VoyageAIServiceSettings.fromMap( new HashMap<>( Map.of( - ServiceFields.URL, - url, VoyageAIServiceSettings.MODEL_ID, model, RateLimitSettings.FIELD_NAME, @@ -77,65 +63,21 @@ public void testFromMap_WithRateLimit() { ConfigurationParseContext.REQUEST ); - MatcherAssert.assertThat( - serviceSettings, - is(new VoyageAIServiceSettings(ServiceUtils.createUri(url), model, new RateLimitSettings(3))) - ); + MatcherAssert.assertThat(serviceSettings, is(new VoyageAIServiceSettings(model, new RateLimitSettings(3)))); } public void testFromMap_WhenUsingModelId() { - var url = "https://www.abc.com"; var model = "model"; var serviceSettings = VoyageAIServiceSettings.fromMap( - new HashMap<>(Map.of(ServiceFields.URL, url, VoyageAIServiceSettings.MODEL_ID, model)), - ConfigurationParseContext.PERSISTENT - ); - - MatcherAssert.assertThat(serviceSettings, is(new VoyageAIServiceSettings(ServiceUtils.createUri(url), model, null))); - } - - public void testFromMap_MissingUrl_DoesNotThrowException() { - var serviceSettings = VoyageAIServiceSettings.fromMap( - new HashMap<>(Map.of(VoyageAIServiceSettings.MODEL_ID, "model")), + new HashMap<>(Map.of(VoyageAIServiceSettings.MODEL_ID, model)), ConfigurationParseContext.PERSISTENT ); - assertNull(serviceSettings.uri()); - } - - public void testFromMap_EmptyUrl_ThrowsError() { - var thrownException = expectThrows( - ValidationException.class, - () -> VoyageAIServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, "")), ConfigurationParseContext.PERSISTENT) - ); - - MatcherAssert.assertThat( - thrownException.getMessage(), - containsString( - Strings.format( - "Validation Failed: 1: [service_settings] Invalid value empty string. [%s] must be a non-empty string;", - ServiceFields.URL - ) - ) - ); - } - - public void testFromMap_InvalidUrl_ThrowsError() { - var url = "https://www.abc^.com"; - var thrownException = expectThrows( - ValidationException.class, - () -> VoyageAIServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, url)), ConfigurationParseContext.PERSISTENT) - ); - MatcherAssert.assertThat( - thrownException.getMessage(), - containsString( - Strings.format("Validation Failed: 1: [service_settings] Invalid url [%s] received for field [%s]", url, ServiceFields.URL) - ) - ); + MatcherAssert.assertThat(serviceSettings, is(new VoyageAIServiceSettings(model, null))); } public void testXContent_WritesModelId() throws IOException { - var entity = new VoyageAIServiceSettings((String) null, "model", new RateLimitSettings(1)); + var entity = new VoyageAIServiceSettings("model", new RateLimitSettings(1)); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -160,13 +102,9 @@ protected VoyageAIServiceSettings mutateInstance(VoyageAIServiceSettings instanc return randomValueOtherThan(instance, VoyageAIServiceSettingsTests::createRandom); } - public static Map getServiceSettingsMap(@Nullable String url, String model) { + public static Map getServiceSettingsMap(String model) { var map = new HashMap(); - if (url != null) { - map.put(ServiceFields.URL, url); - } - map.put(VoyageAIServiceSettings.MODEL_ID, model); return map; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java index e16e315770e8c..d45c78002fa4d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java @@ -52,6 +52,7 @@ import org.junit.Before; import java.io.IOException; +import java.net.URISyntaxException; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -107,7 +108,8 @@ public void testParseRequestConfig_CreatesAVoyageAIEmbeddingsModel() throws IOEx MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); var embeddingsModel = (VoyageAIEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.buildUri().toString(), is("https://api.voyageai.com/v1/embeddings")); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null))); MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); @@ -117,7 +119,7 @@ public void testParseRequestConfig_CreatesAVoyageAIEmbeddingsModel() throws IOEx "id", TaskType.TEXT_EMBEDDING, getRequestConfigMap( - VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST), getSecretSettingsMap("secret") ), @@ -133,7 +135,8 @@ public void testParseRequestConfig_CreatesAVoyageAIEmbeddingsModelWhenChunkingSe MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); var embeddingsModel = (VoyageAIEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.buildUri().toString(), is("https://api.voyageai.com/v1/embeddings")); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null))); MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); @@ -145,7 +148,7 @@ public void testParseRequestConfig_CreatesAVoyageAIEmbeddingsModelWhenChunkingSe "id", TaskType.TEXT_EMBEDDING, getRequestConfigMap( - VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST), createRandomChunkingSettingsMap(), getSecretSettingsMap("secret") @@ -162,7 +165,8 @@ public void testParseRequestConfig_CreatesAVoyageAIEmbeddingsModelWhenChunkingSe MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); var embeddingsModel = (VoyageAIEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.buildUri().toString(), is("https://api.voyageai.com/v1/embeddings")); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null))); MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); @@ -174,7 +178,7 @@ public void testParseRequestConfig_CreatesAVoyageAIEmbeddingsModelWhenChunkingSe "id", TaskType.TEXT_EMBEDDING, getRequestConfigMap( - VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST), getSecretSettingsMap("secret") ), @@ -191,7 +195,8 @@ public void testParseRequestConfig_OptionalTaskSettings() throws IOException { MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); var embeddingsModel = (VoyageAIEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.buildUri().toString(), is("https://api.voyageai.com/v1/embeddings")); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), equalTo(VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS)); MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); @@ -200,10 +205,7 @@ public void testParseRequestConfig_OptionalTaskSettings() throws IOException { service.parseRequestConfig( "id", TaskType.TEXT_EMBEDDING, - getRequestConfigMap( - VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), - getSecretSettingsMap("secret") - ), + getRequestConfigMap(VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), getSecretSettingsMap("secret")), modelListener ); @@ -221,7 +223,7 @@ public void testParseRequestConfig_ThrowsUnsupportedTaskType() throws IOExceptio "id", TaskType.SPARSE_EMBEDDING, getRequestConfigMap( - VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(), getSecretSettingsMap("secret") ), @@ -240,7 +242,7 @@ private static ActionListener getModelListenerForException(Class excep public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException { try (var service = createVoyageAIService()) { var config = getRequestConfigMap( - VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(), getSecretSettingsMap("secret") ); @@ -256,7 +258,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException { try (var service = createVoyageAIService()) { - var serviceSettings = VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"); + var serviceSettings = VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"); serviceSettings.put("extra_key", "value"); var config = getRequestConfigMap( @@ -279,7 +281,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() taskSettingsMap.put("extra_key", "value"); var config = getRequestConfigMap( - VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), taskSettingsMap, getSecretSettingsMap("secret") ); @@ -299,7 +301,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap secretSettingsMap.put("extra_key", "value"); var config = getRequestConfigMap( - VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(), secretSettingsMap ); @@ -312,35 +314,10 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap } } - public void testParseRequestConfig_CreatesAVoyageAIEmbeddingsModelWithoutUrl() throws IOException { - try (var service = createVoyageAIService()) { - var modelListener = ActionListener.wrap((model) -> { - MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); - - var embeddingsModel = (VoyageAIEmbeddingsModel) model; - assertNull(embeddingsModel.getServiceSettings().getCommonSettings().uri()); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS)); - MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - }, (e) -> fail("Model parsing should have succeeded " + e.getMessage())); - - service.parseRequestConfig( - "id", - TaskType.TEXT_EMBEDDING, - getRequestConfigMap( - VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, "model"), - VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(), - getSecretSettingsMap("secret") - ), - modelListener - ); - - } - } - public void testParsePersistedConfigWithSecrets_CreatesAVoyageAIEmbeddingsModel() throws IOException { try (var service = createVoyageAIService()) { var persistedConfig = getPersistedConfigMap( - VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), getSecretSettingsMap("secret") ); @@ -355,17 +332,20 @@ public void testParsePersistedConfigWithSecrets_CreatesAVoyageAIEmbeddingsModel( MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); var embeddingsModel = (VoyageAIEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.buildUri().toString(), is("https://api.voyageai.com/v1/embeddings")); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null, null))); MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } catch (URISyntaxException e) { + throw new RuntimeException(e); } } public void testParsePersistedConfigWithSecrets_CreatesAVoyageAIEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { try (var service = createVoyageAIService()) { var persistedConfig = getPersistedConfigMap( - VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), createRandomChunkingSettingsMap(), getSecretSettingsMap("secret") @@ -381,18 +361,21 @@ public void testParsePersistedConfigWithSecrets_CreatesAVoyageAIEmbeddingsModelW MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); var embeddingsModel = (VoyageAIEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.buildUri().toString(), is("https://api.voyageai.com/v1/embeddings")); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null, null))); MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } catch (URISyntaxException e) { + throw new RuntimeException(e); } } public void testParsePersistedConfigWithSecrets_CreatesAVoyageAIEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { try (var service = createVoyageAIService()) { var persistedConfig = getPersistedConfigMap( - VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), getSecretSettingsMap("secret") ); @@ -407,18 +390,21 @@ public void testParsePersistedConfigWithSecrets_CreatesAVoyageAIEmbeddingsModelW MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); var embeddingsModel = (VoyageAIEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.buildUri().toString(), is("https://api.voyageai.com/v1/embeddings")); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null, null))); MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } catch (URISyntaxException e) { + throw new RuntimeException(e); } } public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidModel() throws IOException { try (var service = createVoyageAIService()) { var persistedConfig = getPersistedConfigMap( - VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "oldmodel"), + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("oldmodel"), VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(), getSecretSettingsMap("secret") ); @@ -440,35 +426,10 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM } } - public void testParsePersistedConfigWithSecrets_CreatesAVoyageAIEmbeddingsModelWithoutUrl() throws IOException { - try (var service = createVoyageAIService()) { - var persistedConfig = getPersistedConfigMap( - VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, "model"), - VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST), - getSecretSettingsMap("secret") - ); - - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() - ); - - MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); - - var embeddingsModel = (VoyageAIEmbeddingsModel) model; - assertNull(embeddingsModel.getServiceSettings().getCommonSettings().uri()); - MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null))); - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - } - } - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { try (var service = createVoyageAIService()) { var persistedConfig = getPersistedConfigMap( - VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH), getSecretSettingsMap("secret") ); @@ -484,10 +445,13 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); var embeddingsModel = (VoyageAIEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.buildUri().toString(), is("https://api.voyageai.com/v1/embeddings")); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH, null))); MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } catch (URISyntaxException e) { + throw new RuntimeException(e); } } @@ -497,7 +461,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists secretSettingsMap.put("extra_key", "value"); var persistedConfig = getPersistedConfigMap( - VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(), secretSettingsMap ); @@ -512,17 +476,20 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); var embeddingsModel = (VoyageAIEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.buildUri().toString(), is("https://api.voyageai.com/v1/embeddings")); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS)); MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } catch (URISyntaxException e) { + throw new RuntimeException(e); } } public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSecrets() throws IOException { try (var service = createVoyageAIService()) { var persistedConfig = getPersistedConfigMap( - VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), getSecretSettingsMap("secret") ); @@ -538,16 +505,19 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSe MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); var embeddingsModel = (VoyageAIEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.buildUri().toString(), is("https://api.voyageai.com/v1/embeddings")); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null, null))); MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } catch (URISyntaxException e) { + throw new RuntimeException(e); } } public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { try (var service = createVoyageAIService()) { - var serviceSettingsMap = VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"); + var serviceSettingsMap = VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"); serviceSettingsMap.put("extra_key", "value"); var persistedConfig = getPersistedConfigMap( @@ -566,10 +536,13 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSe MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); var embeddingsModel = (VoyageAIEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.buildUri().toString(), is("https://api.voyageai.com/v1/embeddings")); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS)); MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } catch (URISyntaxException e) { + throw new RuntimeException(e); } } @@ -579,7 +552,7 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTa taskSettingsMap.put("extra_key", "value"); var persistedConfig = getPersistedConfigMap( - VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), taskSettingsMap, getSecretSettingsMap("secret") ); @@ -594,17 +567,20 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTa MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); var embeddingsModel = (VoyageAIEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.buildUri().toString(), is("https://api.voyageai.com/v1/embeddings")); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH, null))); MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } catch (URISyntaxException e) { + throw new RuntimeException(e); } } public void testParsePersistedConfig_CreatesAVoyageAIEmbeddingsModel() throws IOException { try (var service = createVoyageAIService()) { var persistedConfig = getPersistedConfigMap( - VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null) ); @@ -613,17 +589,20 @@ public void testParsePersistedConfig_CreatesAVoyageAIEmbeddingsModel() throws IO MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); var embeddingsModel = (VoyageAIEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.buildUri().toString(), is("https://api.voyageai.com/v1/embeddings")); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null, null))); assertNull(embeddingsModel.getSecretSettings()); + } catch (URISyntaxException e) { + throw new RuntimeException(e); } } public void testParsePersistedConfig_CreatesAVoyageAIEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { try (var service = createVoyageAIService()) { var persistedConfig = getPersistedConfigMap( - VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), createRandomChunkingSettingsMap() ); @@ -633,18 +612,21 @@ public void testParsePersistedConfig_CreatesAVoyageAIEmbeddingsModelWhenChunking MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); var embeddingsModel = (VoyageAIEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.buildUri().toString(), is("https://api.voyageai.com/v1/embeddings")); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null, null))); MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); assertNull(embeddingsModel.getSecretSettings()); + } catch (URISyntaxException e) { + throw new RuntimeException(e); } } public void testParsePersistedConfig_CreatesAVoyageAIEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { try (var service = createVoyageAIService()) { var persistedConfig = getPersistedConfigMap( - VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null) ); @@ -653,18 +635,21 @@ public void testParsePersistedConfig_CreatesAVoyageAIEmbeddingsModelWhenChunking MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); var embeddingsModel = (VoyageAIEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.buildUri().toString(), is("https://api.voyageai.com/v1/embeddings")); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null, null))); MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); assertNull(embeddingsModel.getSecretSettings()); + } catch (URISyntaxException e) { + throw new RuntimeException(e); } } public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() throws IOException { try (var service = createVoyageAIService()) { var persistedConfig = getPersistedConfigMap( - VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model_old"), + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model_old"), VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty() ); @@ -683,7 +668,7 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro public void testParsePersistedConfig_CreatesAVoyageAIEmbeddingsModelWithoutUrl() throws IOException { try (var service = createVoyageAIService()) { var persistedConfig = getPersistedConfigMap( - VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, "model"), + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null) ); @@ -692,7 +677,7 @@ public void testParsePersistedConfig_CreatesAVoyageAIEmbeddingsModelWithoutUrl() MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); var embeddingsModel = (VoyageAIEmbeddingsModel) model; - assertNull(embeddingsModel.getServiceSettings().getCommonSettings().uri()); + assertNull(embeddingsModel.uri()); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings((InputType) null, null))); assertNull(embeddingsModel.getSecretSettings()); @@ -702,7 +687,7 @@ public void testParsePersistedConfig_CreatesAVoyageAIEmbeddingsModelWithoutUrl() public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { try (var service = createVoyageAIService()) { var persistedConfig = getPersistedConfigMap( - VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty() ); persistedConfig.config().put("extra_key", "value"); @@ -712,16 +697,19 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); var embeddingsModel = (VoyageAIEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.buildUri().toString(), is("https://api.voyageai.com/v1/embeddings")); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS)); assertNull(embeddingsModel.getSecretSettings()); + } catch (URISyntaxException e) { + throw new RuntimeException(e); } } public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { try (var service = createVoyageAIService()) { - var serviceSettingsMap = VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"); + var serviceSettingsMap = VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"); serviceSettingsMap.put("extra_key", "value"); var persistedConfig = getPersistedConfigMap( @@ -734,10 +722,13 @@ public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInServiceSettin MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); var embeddingsModel = (VoyageAIEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.buildUri().toString(), is("https://api.voyageai.com/v1/embeddings")); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH, null))); assertNull(embeddingsModel.getSecretSettings()); + } catch (URISyntaxException e) { + throw new RuntimeException(e); } } @@ -747,7 +738,7 @@ public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings( taskSettingsMap.put("extra_key", "value"); var persistedConfig = getPersistedConfigMap( - VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), taskSettingsMap ); @@ -755,10 +746,13 @@ public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings( MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); var embeddingsModel = (VoyageAIEmbeddingsModel) model; - MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + assertNull(embeddingsModel.uri()); + MatcherAssert.assertThat(embeddingsModel.buildUri().toString(), is("https://api.voyageai.com/v1/embeddings")); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null))); assertNull(embeddingsModel.getSecretSettings()); + } catch (URISyntaxException e) { + throw new RuntimeException(e); } } @@ -805,11 +799,10 @@ public void testCheckModelConfig_UpdatesDimensions() throws IOException { String responseJson = """ { - "model": "voyage-clip-v2", + "model": "voyage-3-large", "object": "list", "usage": { - "total_tokens": 5, - "prompt_tokens": 5 + "total_tokens": 5 }, "data": [ { @@ -831,7 +824,7 @@ public void testCheckModelConfig_UpdatesDimensions() throws IOException { VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, 10, 1, - "voyage-clip-v2" + "voyage-3-large" ); PlainActionFuture listener = new PlainActionFuture<>(); service.checkModelConfig(model, listener); @@ -847,7 +840,7 @@ public void testCheckModelConfig_UpdatesDimensions() throws IOException { VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, 10, 2, - "voyage-clip-v2" + "voyage-3-large" ) ) ); @@ -861,11 +854,10 @@ public void testCheckModelConfig_UpdatesSimilarityToDotProduct_WhenItIsNull() th String responseJson = """ { - "model": "voyage-clip-v2", + "model": "voyage-3-large", "object": "list", "usage": { - "total_tokens": 5, - "prompt_tokens": 5 + "total_tokens": 5 }, "data": [ { @@ -887,7 +879,7 @@ public void testCheckModelConfig_UpdatesSimilarityToDotProduct_WhenItIsNull() th VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, 10, 1, - "voyage-clip-v2", + "voyage-3-large", (SimilarityMeasure) null ); PlainActionFuture listener = new PlainActionFuture<>(); @@ -904,7 +896,7 @@ public void testCheckModelConfig_UpdatesSimilarityToDotProduct_WhenItIsNull() th VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, 10, 2, - "voyage-clip-v2", + "voyage-3-large", SimilarityMeasure.DOT_PRODUCT ) ) @@ -919,11 +911,10 @@ public void testCheckModelConfig_DoesNotUpdateSimilarity_WhenItIsSpecifiedAsCosi String responseJson = """ { - "model": "voyage-clip-v2", + "model": "voyage-3-large", "object": "list", "usage": { - "total_tokens": 5, - "prompt_tokens": 5 + "total_tokens": 5 }, "data": [ { @@ -945,7 +936,7 @@ public void testCheckModelConfig_DoesNotUpdateSimilarity_WhenItIsSpecifiedAsCosi VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, 10, 1, - "voyage-clip-v2", + "voyage-3-large", SimilarityMeasure.COSINE ); PlainActionFuture listener = new PlainActionFuture<>(); @@ -962,7 +953,7 @@ public void testCheckModelConfig_DoesNotUpdateSimilarity_WhenItIsSpecifiedAsCosi VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, 10, 2, - "voyage-clip-v2", + "voyage-3-large", SimilarityMeasure.COSINE ) ) @@ -1082,11 +1073,10 @@ public void testInfer_Embedding_Get_Response_Ingest() throws IOException { String responseJson = """ { - "model": "voyage-clip-v2", + "model": "voyage-3-large", "object": "list", "usage": { - "total_tokens": 5, - "prompt_tokens": 5 + "total_tokens": 5 }, "data": [ { @@ -1108,7 +1098,7 @@ public void testInfer_Embedding_Get_Response_Ingest() throws IOException { VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, 1024, 1024, - "voyage-clip-v2", + "voyage-3-large", (SimilarityMeasure) null ); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1143,7 +1133,7 @@ public void testInfer_Embedding_Get_Response_Ingest() throws IOException { "input", List.of("abc"), "model", - "voyage-clip-v2", + "voyage-3-large", "input_type", "document", "output_dtype", @@ -1163,11 +1153,10 @@ public void testInfer_Embedding_Get_Response_Search() throws IOException { String responseJson = """ { - "model": "voyage-clip-v2", + "model": "voyage-3-large", "object": "list", "usage": { - "total_tokens": 5, - "prompt_tokens": 5 + "total_tokens": 5 }, "data": [ { @@ -1189,7 +1178,7 @@ public void testInfer_Embedding_Get_Response_Search() throws IOException { VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, 1024, 1024, - "voyage-clip-v2", + "voyage-3-large", (SimilarityMeasure) null ); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1224,7 +1213,7 @@ public void testInfer_Embedding_Get_Response_Search() throws IOException { "input", List.of("abc"), "model", - "voyage-clip-v2", + "voyage-3-large", "input_type", "query", "output_dtype", @@ -1243,7 +1232,7 @@ public void testInfer_Embedding_Get_Response_clustering() throws IOException { try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { String responseJson = """ - {"model":"voyage-clip-v2","object":"list","usage":{"total_tokens":5,"prompt_tokens":5}, + {"model":"voyage-3-large","object":"list","usage":{"total_tokens":5}, "data":[{"object":"embedding","index":0,"embedding":[0.123, -0.123]}]} """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); @@ -1254,7 +1243,7 @@ public void testInfer_Embedding_Get_Response_clustering() throws IOException { VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, 1024, 1024, - "voyage-clip-v2", + "voyage-3-large", (SimilarityMeasure) null ); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1284,7 +1273,7 @@ public void testInfer_Embedding_Get_Response_clustering() throws IOException { var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); MatcherAssert.assertThat( requestMap, - is(Map.of("input", List.of("abc"), "model", "voyage-clip-v2", "output_dtype", "float", "output_dimension", 1024)) + is(Map.of("input", List.of("abc"), "model", "voyage-3-large", "output_dtype", "float", "output_dimension", 1024)) ); } } @@ -1296,11 +1285,10 @@ public void testInfer_Embedding_Get_Response_NullInputType() throws IOException String responseJson = """ { - "model": "voyage-clip-v2", + "model": "voyage-3-large", "object": "list", "usage": { - "total_tokens": 5, - "prompt_tokens": 5 + "total_tokens": 5 }, "data": [ { @@ -1322,7 +1310,7 @@ public void testInfer_Embedding_Get_Response_NullInputType() throws IOException VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, 1024, 1024, - "voyage-clip-v2", + "voyage-3-large", (SimilarityMeasure) null ); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1343,7 +1331,7 @@ public void testInfer_Embedding_Get_Response_NullInputType() throws IOException var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); MatcherAssert.assertThat( requestMap, - is(Map.of("input", List.of("abc"), "model", "voyage-clip-v2", "output_dtype", "float", "output_dimension", 1024)) + is(Map.of("input", List.of("abc"), "model", "voyage-3-large", "output_dtype", "float", "output_dimension", 1024)) ); } } @@ -1697,11 +1685,10 @@ public void testInfer_Embedding_DoesNotSetInputType_WhenNotPresentInTaskSettings String responseJson = """ { - "model": "voyage-clip-v2", + "model": "voyage-3-large", "object": "list", "usage": { - "total_tokens": 5, - "prompt_tokens": 5 + "total_tokens": 5 }, "data": [ { @@ -1723,7 +1710,7 @@ public void testInfer_Embedding_DoesNotSetInputType_WhenNotPresentInTaskSettings new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), 1024, 1024, - "voyage-clip-v2", + "voyage-3-large", (SimilarityMeasure) null ); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1757,7 +1744,7 @@ public void testInfer_Embedding_DoesNotSetInputType_WhenNotPresentInTaskSettings "input", List.of("abc"), "model", - "voyage-clip-v2", + "voyage-3-large", "input_type", "document", "output_dtype", @@ -1778,7 +1765,7 @@ public void test_Embedding_ChunkedInfer_BatchesCallsChunkingSettingsSet() throws createRandomChunkingSettings(), 1024, 1024, - "voyage-clip-v2" + "voyage-3-large" ); test_Embedding_ChunkedInfer_BatchesCalls(model); @@ -1792,7 +1779,7 @@ public void test_Embedding_ChunkedInfer_ChunkingSettingsNotSet() throws IOExcept null, 1024, 1024, - "voyage-clip-v2" + "voyage-3-large" ); test_Embedding_ChunkedInfer_BatchesCalls(model); @@ -1806,11 +1793,10 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(VoyageAIEmbeddingsModel mo // Batching will call the service with 2 input String responseJson = """ { - "model": "voyage-clip-v2", + "model": "voyage-3-large", "object": "list", "usage": { - "total_tokens": 5, - "prompt_tokens": 5 + "total_tokens": 5 }, "data": [ { @@ -1874,7 +1860,7 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(VoyageAIEmbeddingsModel mo var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); MatcherAssert.assertThat( requestMap, - is(Map.of("input", List.of("foo", "bar"), "model", "voyage-clip-v2", "output_dtype", "float", "output_dimension", 1024)) + is(Map.of("input", List.of("foo", "bar"), "model", "voyage-3-large", "output_dtype", "float", "output_dimension", 1024)) ); } } @@ -1892,6 +1878,15 @@ public void testGetConfiguration() throws Exception { "name": "Voyage AI", "task_types": ["text_embedding", "rerank"], "configurations": { + "model_id": { + "description": "The name of the model to use for the inference task.", + "label": "Model ID", + "required": true, + "sensitive": false, + "updatable": false, + "type": "str", + "supported_task_types": ["text_embedding", "rerank"] + }, "api_key": { "description": "API Key for the provider you're connecting to.", "label": "API Key", diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsModelTests.java index 32a03a26f0323..e497e606b0689 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsModelTests.java @@ -112,12 +112,14 @@ public static VoyageAIEmbeddingsModel createModel( return new VoyageAIEmbeddingsModel( "id", "service", + url, new VoyageAIEmbeddingsServiceSettings( - new VoyageAIServiceSettings(url, model, null), + new VoyageAIServiceSettings(model, null), VoyageAIEmbeddingType.FLOAT, SimilarityMeasure.DOT_PRODUCT, dimensions, - tokenLimit + tokenLimit, + false ), taskSettings, chunkingSettings, @@ -136,12 +138,14 @@ public static VoyageAIEmbeddingsModel createModel( return new VoyageAIEmbeddingsModel( "id", "service", + url, new VoyageAIEmbeddingsServiceSettings( - new VoyageAIServiceSettings(url, model, null), + new VoyageAIServiceSettings(model, null), VoyageAIEmbeddingType.FLOAT, SimilarityMeasure.DOT_PRODUCT, dimensions, - tokenLimit + tokenLimit, + false ), taskSettings, null, @@ -161,12 +165,14 @@ public static VoyageAIEmbeddingsModel createModel( return new VoyageAIEmbeddingsModel( "id", "service", + url, new VoyageAIEmbeddingsServiceSettings( - new VoyageAIServiceSettings(url, model, null), + new VoyageAIServiceSettings(model, null), embeddingType, SimilarityMeasure.DOT_PRODUCT, dimensions, - tokenLimit + tokenLimit, + false ), taskSettings, null, @@ -186,12 +192,14 @@ public static VoyageAIEmbeddingsModel createModel( return new VoyageAIEmbeddingsModel( "id", "service", + url, new VoyageAIEmbeddingsServiceSettings( - new VoyageAIServiceSettings(url, model, null), + new VoyageAIServiceSettings(model, null), VoyageAIEmbeddingType.FLOAT, similarityMeasure, dimensions, - tokenLimit + tokenLimit, + false ), taskSettings, null, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettingsTests.java index c4980e02ca42d..0bedd0d60c25c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettingsTests.java @@ -11,7 +11,6 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xcontent.XContentBuilder; @@ -21,7 +20,6 @@ import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceFields; -import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings; import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettingsTests; @@ -33,23 +31,29 @@ import java.util.List; import java.util.Map; +import static org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER; import static org.hamcrest.Matchers.is; public class VoyageAIEmbeddingsServiceSettingsTests extends AbstractWireSerializingTestCase { public static VoyageAIEmbeddingsServiceSettings createRandom() { - SimilarityMeasure similarityMeasure = null; - Integer dims = null; - similarityMeasure = SimilarityMeasure.DOT_PRODUCT; - dims = 1024; + SimilarityMeasure similarityMeasure = SimilarityMeasure.DOT_PRODUCT; + Integer dims = 1024; Integer maxInputTokens = randomBoolean() ? null : randomIntBetween(128, 256); + Boolean dimensionSetByUser = randomBoolean(); var commonSettings = VoyageAIServiceSettingsTests.createRandom(); - return new VoyageAIEmbeddingsServiceSettings(commonSettings, VoyageAIEmbeddingType.FLOAT, similarityMeasure, dims, maxInputTokens); + return new VoyageAIEmbeddingsServiceSettings( + commonSettings, + VoyageAIEmbeddingType.FLOAT, + similarityMeasure, + dims, + maxInputTokens, + dimensionSetByUser + ); } public void testFromMap() { - var url = "https://www.abc.com"; var similarity = SimilarityMeasure.DOT_PRODUCT.toString(); var dims = 1536; var maxInputTokens = 512; @@ -57,8 +61,6 @@ public void testFromMap() { var serviceSettings = VoyageAIEmbeddingsServiceSettings.fromMap( new HashMap<>( Map.of( - ServiceFields.URL, - url, ServiceFields.SIMILARITY, similarity, ServiceFields.DIMENSIONS, @@ -76,18 +78,51 @@ public void testFromMap() { serviceSettings, is( new VoyageAIEmbeddingsServiceSettings( - new VoyageAIServiceSettings(ServiceUtils.createUri(url), model, null), + new VoyageAIServiceSettings(model, null), VoyageAIEmbeddingType.FLOAT, SimilarityMeasure.DOT_PRODUCT, dims, - maxInputTokens + maxInputTokens, + false ) ) ); } public void testFromMap_WithModelId() { - var url = "https://www.abc.com"; + var similarity = SimilarityMeasure.DOT_PRODUCT.toString(); + var maxInputTokens = 512; + var model = "model"; + var serviceSettings = VoyageAIEmbeddingsServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.SIMILARITY, + similarity, + ServiceFields.MAX_INPUT_TOKENS, + maxInputTokens, + VoyageAIServiceSettings.MODEL_ID, + model + ) + ), + ConfigurationParseContext.REQUEST + ); + + MatcherAssert.assertThat( + serviceSettings, + is( + new VoyageAIEmbeddingsServiceSettings( + new VoyageAIServiceSettings(model, null), + VoyageAIEmbeddingType.FLOAT, + SimilarityMeasure.DOT_PRODUCT, + null, + maxInputTokens, + false + ) + ) + ); + } + + public void testFromMap_WithModelId_WithDimensions() { var similarity = SimilarityMeasure.DOT_PRODUCT.toString(); var dims = 1536; var maxInputTokens = 512; @@ -95,8 +130,6 @@ public void testFromMap_WithModelId() { var serviceSettings = VoyageAIEmbeddingsServiceSettings.fromMap( new HashMap<>( Map.of( - ServiceFields.URL, - url, ServiceFields.SIMILARITY, similarity, ServiceFields.DIMENSIONS, @@ -114,11 +147,83 @@ public void testFromMap_WithModelId() { serviceSettings, is( new VoyageAIEmbeddingsServiceSettings( - new VoyageAIServiceSettings(ServiceUtils.createUri(url), model, null), + new VoyageAIServiceSettings(model, null), VoyageAIEmbeddingType.FLOAT, SimilarityMeasure.DOT_PRODUCT, dims, - maxInputTokens + maxInputTokens, + true + ) + ) + ); + } + + public void testFromMap_DimensionsSetByUserIsFalseInRequestContext() { + var similarity = SimilarityMeasure.DOT_PRODUCT.toString(); + var maxInputTokens = 512; + var model = "model"; + var serviceSettings = VoyageAIEmbeddingsServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.SIMILARITY, + similarity, + DIMENSIONS_SET_BY_USER, + true, + ServiceFields.MAX_INPUT_TOKENS, + maxInputTokens, + VoyageAIServiceSettings.MODEL_ID, + model + ) + ), + ConfigurationParseContext.REQUEST + ); + + MatcherAssert.assertThat( + serviceSettings, + is( + new VoyageAIEmbeddingsServiceSettings( + new VoyageAIServiceSettings(model, null), + VoyageAIEmbeddingType.FLOAT, + SimilarityMeasure.DOT_PRODUCT, + null, + maxInputTokens, + false + ) + ) + ); + } + + public void testFromMap_DimensionsSetByUserIsSetInPersistentContext() { + var similarity = SimilarityMeasure.DOT_PRODUCT.toString(); + var maxInputTokens = 512; + var model = "model"; + var dimensionsSetByUser = randomBoolean(); + var serviceSettings = VoyageAIEmbeddingsServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.SIMILARITY, + similarity, + DIMENSIONS_SET_BY_USER, + dimensionsSetByUser, + ServiceFields.MAX_INPUT_TOKENS, + maxInputTokens, + VoyageAIServiceSettings.MODEL_ID, + model + ) + ), + ConfigurationParseContext.PERSISTENT + ); + + MatcherAssert.assertThat( + serviceSettings, + is( + new VoyageAIEmbeddingsServiceSettings( + new VoyageAIServiceSettings(model, null), + VoyageAIEmbeddingType.FLOAT, + SimilarityMeasure.DOT_PRODUCT, + null, + maxInputTokens, + dimensionsSetByUser ) ) ); @@ -146,11 +251,37 @@ public void testFromMap_InvalidSimilarity_ThrowsError() { @SuppressWarnings("checkstyle:LineLength") public void testToXContent_WritesAllValues() throws IOException { var serviceSettings = new VoyageAIEmbeddingsServiceSettings( - new VoyageAIServiceSettings("url", "model", new RateLimitSettings(3)), + new VoyageAIServiceSettings("model", new RateLimitSettings(3)), + VoyageAIEmbeddingType.FLOAT, + SimilarityMeasure.COSINE, + 5, + 10, + false + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + assertThat( + xContentResult, + is( + """ + {"model_id":"model",""" + + """ + "rate_limit":{"requests_per_minute":3},"similarity":"cosine","dimensions":5,"max_input_tokens":10,"embedding_type":"float"}""" + ) + ); + } + + @SuppressWarnings("checkstyle:LineLength") + public void testToXContent_WritesAllValues_DimensionSetByUser() throws IOException { + var serviceSettings = new VoyageAIEmbeddingsServiceSettings( + new VoyageAIServiceSettings("model", new RateLimitSettings(3)), VoyageAIEmbeddingType.FLOAT, SimilarityMeasure.COSINE, 5, - 10 + 10, + true ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); @@ -160,7 +291,7 @@ public void testToXContent_WritesAllValues() throws IOException { xContentResult, is( """ - {"url":"url","model_id":"model",""" + {"model_id":"model",""" + """ "rate_limit":{"requests_per_minute":3},"similarity":"cosine","dimensions":5,"max_input_tokens":10,"embedding_type":"float"}""" ) @@ -190,7 +321,7 @@ protected NamedWriteableRegistry getNamedWriteableRegistry() { return new NamedWriteableRegistry(entries); } - public static Map getServiceSettingsMap(@Nullable String url, String model) { - return new HashMap<>(VoyageAIServiceSettingsTests.getServiceSettingsMap(url, model)); + public static Map getServiceSettingsMap(String model) { + return new HashMap<>(VoyageAIServiceSettingsTests.getServiceSettingsMap(model)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModelTests.java index bca6cc94d9afd..d9ceca107f9e4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModelTests.java @@ -18,7 +18,8 @@ public static VoyageAIRerankModel createModel(String apiKey, String modelId, @Nu return new VoyageAIRerankModel( "id", "service", - new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(ESTestCase.randomAlphaOfLength(10), modelId, null)), + ESTestCase.randomAlphaOfLength(10), + new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(modelId, null)), new VoyageAIRerankTaskSettings(topK, null, truncation), new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); @@ -28,7 +29,8 @@ public static VoyageAIRerankModel createModel(String apiKey, String modelId, @Nu return new VoyageAIRerankModel( "id", "service", - new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(ESTestCase.randomAlphaOfLength(10), modelId, null)), + ESTestCase.randomAlphaOfLength(10), + new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(modelId, null)), new VoyageAIRerankTaskSettings(topK, null, null), new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); @@ -38,7 +40,8 @@ public static VoyageAIRerankModel createModel(String modelId, @Nullable Integer return new VoyageAIRerankModel( "id", "service", - new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(ESTestCase.randomAlphaOfLength(10), modelId, null)), + ESTestCase.randomAlphaOfLength(10), + new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(modelId, null)), new VoyageAIRerankTaskSettings(topK, null, null), new DefaultSecretSettings(ESTestCase.randomSecureStringOfLength(8)) ); @@ -48,7 +51,8 @@ public static VoyageAIRerankModel createModel(String modelId, @Nullable Integer return new VoyageAIRerankModel( "id", "service", - new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(ESTestCase.randomAlphaOfLength(10), modelId, null)), + ESTestCase.randomAlphaOfLength(10), + new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(modelId, null)), new VoyageAIRerankTaskSettings(topK, returnDocuments, truncation), new DefaultSecretSettings(ESTestCase.randomSecureStringOfLength(8)) ); @@ -64,7 +68,8 @@ public static VoyageAIRerankModel createModel( return new VoyageAIRerankModel( "id", "service", - new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(url, modelId, null)), + url, + new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(modelId, null)), new VoyageAIRerankTaskSettings(topK, returnDocuments, truncation), new DefaultSecretSettings(ESTestCase.randomSecureStringOfLength(8)) ); @@ -81,7 +86,8 @@ public static VoyageAIRerankModel createModel( return new VoyageAIRerankModel( "id", "service", - new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(url, modelId, null)), + url, + new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(modelId, null)), new VoyageAIRerankTaskSettings(topK, returnDocuments, truncation), new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankServiceSettingsTests.java index 429be4a2c31d7..7891d5cc7cca0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankServiceSettingsTests.java @@ -28,11 +28,7 @@ public class VoyageAIRerankServiceSettingsTests extends AbstractBWCWireSerializationTestCase { public static VoyageAIRerankServiceSettings createRandom() { return new VoyageAIRerankServiceSettings( - new VoyageAIServiceSettings( - randomFrom(new String[] { null, Strings.format("http://%s.com", randomAlphaOfLength(8)) }), - randomAlphaOfLength(10), - RateLimitSettingsTests.createRandom() - ) + new VoyageAIServiceSettings(randomAlphaOfLength(10), RateLimitSettingsTests.createRandom()) ); } @@ -40,7 +36,7 @@ public void testToXContent_WritesAllValues() throws IOException { var url = "http://www.abc.com"; var model = "model"; - var serviceSettings = new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(url, model, null)); + var serviceSettings = new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(model, null)); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); serviceSettings.toXContent(builder, null); @@ -48,7 +44,6 @@ public void testToXContent_WritesAllValues() throws IOException { assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" { - "url":"http://www.abc.com", "model_id":"model", "rate_limit": { "requests_per_minute": 2000 @@ -77,7 +72,7 @@ protected VoyageAIRerankServiceSettings mutateInstanceForVersion(VoyageAIRerankS return instance; } - public static Map getServiceSettingsMap(@Nullable String url, @Nullable String model) { - return new HashMap<>(VoyageAIServiceSettingsTests.getServiceSettingsMap(url, model)); + public static Map getServiceSettingsMap(@Nullable String model) { + return new HashMap<>(VoyageAIServiceSettingsTests.getServiceSettingsMap(model)); } } From 850ab5dc799077263a9adbcd3e25be54af8c7f44 Mon Sep 17 00:00:00 2001 From: fzowl Date: Tue, 18 Feb 2025 18:37:14 +0100 Subject: [PATCH 18/20] Changes due to the comments --- .../org/elasticsearch/TransportVersions.java | 2 +- .../VoyageAIEmbeddingsResponseEntity.java | 42 ++++++------------- .../VoyageAIRerankResponseEntity.java | 20 ++------- .../VoyageAIEmbeddingsServiceSettings.java | 6 +-- ...VoyageAIEmbeddingsResponseEntityTests.java | 9 ++-- 5 files changed, 24 insertions(+), 55 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 4ff53f02a23fb..9b7f37ddfa894 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -197,7 +197,7 @@ static TransportVersion def(int id) { public static final TransportVersion SLM_UNHEALTHY_IF_NO_SNAPSHOT_WITHIN = def(9_010_0_00); public static final TransportVersion ESQL_SUPPORT_PARTIAL_RESULTS = def(9_011_0_00); - public static final TransportVersion VOYAGE_AI_INTEGRATION_ADDED = def(9_008_0_00); + public static final TransportVersion VOYAGE_AI_INTEGRATION_ADDED = def(9_012_0_00); /* * STOP! READ THIS FIRST! No, really, * ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _ diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntity.java index 472b277e7e0d8..f2c6cb5db4c32 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntity.java @@ -10,7 +10,6 @@ package org.elasticsearch.xpack.inference.external.response.voyageai; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; -import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; @@ -32,7 +31,6 @@ import java.util.List; import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; -import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; import static org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingType.toLowerCase; public class VoyageAIEmbeddingsResponseEntity { @@ -47,31 +45,29 @@ private static String supportedEmbeddingTypes() { return String.join(", ", validTypes); } - record EmbeddingInt8Result(List entries, String model, String object, @Nullable Usage usage) { + record EmbeddingInt8Result(List entries) { @SuppressWarnings("unchecked") public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( EmbeddingInt8Result.class.getSimpleName(), - args -> new EmbeddingInt8Result((List) args[0], (String) args[1], (String) args[2], (Usage) args[3]) + true, + args -> new EmbeddingInt8Result((List) args[0]) ); static { PARSER.declareObjectArray(constructorArg(), EmbeddingInt8ResultEntry.PARSER::apply, new ParseField("data")); - PARSER.declareString(constructorArg(), new ParseField("model")); - PARSER.declareString(constructorArg(), new ParseField("object")); - PARSER.declareObject(optionalConstructorArg(), Usage.PARSER::apply, new ParseField("usage")); } } - record EmbeddingInt8ResultEntry(String object, Integer index, List embedding) { + record EmbeddingInt8ResultEntry(Integer index, List embedding) { @SuppressWarnings("unchecked") public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( EmbeddingInt8ResultEntry.class.getSimpleName(), - args -> new EmbeddingInt8ResultEntry((String) args[0], (Integer) args[1], (List) args[2]) + true, + args -> new EmbeddingInt8ResultEntry((Integer) args[0], (List) args[1]) ); static { - PARSER.declareString(constructorArg(), new ParseField("object")); PARSER.declareInt(constructorArg(), new ParseField("index")); PARSER.declareIntArray(constructorArg(), new ParseField("embedding")); } @@ -88,31 +84,29 @@ public InferenceByteEmbedding toInferenceByteEmbedding() { } } - record EmbeddingFloatResult(List entries, String model, String object, @Nullable Usage usage) { + record EmbeddingFloatResult(List entries) { @SuppressWarnings("unchecked") public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( EmbeddingFloatResult.class.getSimpleName(), - args -> new EmbeddingFloatResult((List) args[0], (String) args[1], (String) args[2], (Usage) args[3]) + true, + args -> new EmbeddingFloatResult((List) args[0]) ); static { PARSER.declareObjectArray(constructorArg(), EmbeddingFloatResultEntry.PARSER::apply, new ParseField("data")); - PARSER.declareString(constructorArg(), new ParseField("model")); - PARSER.declareString(constructorArg(), new ParseField("object")); - PARSER.declareObject(optionalConstructorArg(), Usage.PARSER::apply, new ParseField("usage")); } } - record EmbeddingFloatResultEntry(String object, Integer index, List embedding) { + record EmbeddingFloatResultEntry(Integer index, List embedding) { @SuppressWarnings("unchecked") public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( EmbeddingFloatResultEntry.class.getSimpleName(), - args -> new EmbeddingFloatResultEntry((String) args[0], (Integer) args[1], (List) args[2]) + true, + args -> new EmbeddingFloatResultEntry((Integer) args[0], (List) args[1]) ); static { - PARSER.declareString(constructorArg(), new ParseField("object")); PARSER.declareInt(constructorArg(), new ParseField("index")); PARSER.declareFloatArray(constructorArg(), new ParseField("embedding")); } @@ -122,18 +116,6 @@ public InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding toInferenceFlo } } - record Usage(Integer totalTokens) { - - public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - Usage.class.getSimpleName(), - args -> new Usage((Integer) args[0]) - ); - - static { - PARSER.declareInt(constructorArg(), new ParseField("total_tokens")); - } - } - /** * Parses the VoyageAI json response. * For a request like: diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIRerankResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIRerankResponseEntity.java index e0ef73bd2c7ae..5438ba3644753 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIRerankResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIRerankResponseEntity.java @@ -33,19 +33,17 @@ public class VoyageAIRerankResponseEntity { private static final Logger logger = LogManager.getLogger(VoyageAIRerankResponseEntity.class); - record RerankResult(List entries, String model, String object, @Nullable Usage usage) { + record RerankResult(List entries) { @SuppressWarnings("unchecked") public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( RerankResult.class.getSimpleName(), - args -> new RerankResult((List) args[0], (String) args[1], (String) args[2], (Usage) args[3]) + true, + args -> new RerankResult((List) args[0]) ); static { PARSER.declareObjectArray(constructorArg(), RerankResultEntry.PARSER::apply, new ParseField("data")); - PARSER.declareString(constructorArg(), new ParseField("model")); - PARSER.declareString(constructorArg(), new ParseField("object")); - PARSER.declareObject(optionalConstructorArg(), Usage.PARSER::apply, new ParseField("usage")); } } @@ -67,18 +65,6 @@ public RankedDocsResults.RankedDoc toRankedDoc() { } } - record Usage(Integer totalTokens) { - - public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - Usage.class.getSimpleName(), - args -> new Usage((Integer) args[0]) - ); - - static { - PARSER.declareInt(constructorArg(), new ParseField("total_tokens")); - } - } - /** * Parses the VoyageAI ranked response. * For a request like: diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettings.java index 15cbe43bc755a..cc4db278d0e2b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettings.java @@ -127,7 +127,7 @@ static VoyageAIEmbeddingType parseEmbeddingType( private final SimilarityMeasure similarity; private final Integer dimensions; private final Integer maxInputTokens; - private final Boolean dimensionsSetByUser; + private final boolean dimensionsSetByUser; public VoyageAIEmbeddingsServiceSettings( VoyageAIServiceSettings commonSettings, @@ -135,14 +135,14 @@ public VoyageAIEmbeddingsServiceSettings( @Nullable SimilarityMeasure similarity, @Nullable Integer dimensions, @Nullable Integer maxInputTokens, - Boolean dimensionsSetByUser + boolean dimensionsSetByUser ) { this.commonSettings = commonSettings; this.similarity = similarity; this.dimensions = dimensions; this.maxInputTokens = maxInputTokens; this.embeddingType = embeddingType; - this.dimensionsSetByUser = Objects.requireNonNull(dimensionsSetByUser); + this.dimensionsSetByUser = dimensionsSetByUser; } public VoyageAIEmbeddingsServiceSettings(StreamInput in) throws IOException { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntityTests.java index 93d92a34b4284..2b1c8fa43af53 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntityTests.java @@ -20,6 +20,7 @@ import java.util.List; import static org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModelTests.createModel; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.mockito.Mockito.mock; @@ -139,14 +140,14 @@ public void testFromResponse_FailsWhenDataFieldIsNotPresent() { ); var thrownException = expectThrows( - XContentParseException.class, + java.lang.IllegalArgumentException.class, () -> VoyageAIEmbeddingsResponseEntity.fromResponse( request, new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ) ); - assertThat(thrownException.getMessage(), is("[3:3] [EmbeddingFloatResult] unknown field [not_data]")); + assertThat(thrownException.getMessage(), is("Required [data]")); } public void testFromResponse_FailsWhenDataFieldNotAnArray() { @@ -183,7 +184,7 @@ public void testFromResponse_FailsWhenDataFieldNotAnArray() { ) ); - assertThat(thrownException.getMessage(), is("[4:15] [EmbeddingFloatResult] failed to parse field [data]")); + assertThat(thrownException.getMessage(), containsString("[EmbeddingFloatResult] failed to parse field [data]")); } public void testFromResponse_FailsWhenEmbeddingsDoesNotExist() { @@ -220,7 +221,7 @@ public void testFromResponse_FailsWhenEmbeddingsDoesNotExist() { ) ); - assertThat(thrownException.getMessage(), is("[7:27] [EmbeddingFloatResult] failed to parse field [data]")); + assertThat(thrownException.getMessage(), containsString("[EmbeddingFloatResult] failed to parse field [data]")); } public void testFromResponse_FailsWhenEmbeddingValueIsAString() { From d6042e7e5a8fb1cc1254ed5d3da715285664edde Mon Sep 17 00:00:00 2001 From: fzowl Date: Wed, 19 Feb 2025 13:52:30 +0100 Subject: [PATCH 19/20] Correcting QA tests --- .../inference/InferenceGetServicesIT.java | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index 9d4cec798964a..29d2b7e375788 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -25,7 +25,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest { @SuppressWarnings("unchecked") public void testGetServicesWithoutTaskType() throws IOException { List services = getAllServices(); - assertThat(services.size(), equalTo(19)); + assertThat(services.size(), equalTo(20)); String[] providers = new String[services.size()]; for (int i = 0; i < services.size(); i++) { @@ -53,6 +53,7 @@ public void testGetServicesWithoutTaskType() throws IOException { "test_reranking_service", "test_service", "text_embedding_test_service", + "voyageai", "watsonxai" ).toArray(), providers @@ -62,7 +63,7 @@ public void testGetServicesWithoutTaskType() throws IOException { @SuppressWarnings("unchecked") public void testGetServicesWithTextEmbeddingTaskType() throws IOException { List services = getServices(TaskType.TEXT_EMBEDDING); - assertThat(services.size(), equalTo(14)); + assertThat(services.size(), equalTo(15)); String[] providers = new String[services.size()]; for (int i = 0; i < services.size(); i++) { @@ -85,6 +86,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException { "mistral", "openai", "text_embedding_test_service", + "voyageai", "watsonxai" ).toArray(), providers @@ -94,7 +96,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException { @SuppressWarnings("unchecked") public void testGetServicesWithRerankTaskType() throws IOException { List services = getServices(TaskType.RERANK); - assertThat(services.size(), equalTo(6)); + assertThat(services.size(), equalTo(7)); String[] providers = new String[services.size()]; for (int i = 0; i < services.size(); i++) { @@ -103,7 +105,15 @@ public void testGetServicesWithRerankTaskType() throws IOException { } assertArrayEquals( - List.of("alibabacloud-ai-search", "cohere", "elasticsearch", "googlevertexai", "jinaai", "test_reranking_service").toArray(), + List.of( + "alibabacloud-ai-search", + "cohere", + "elasticsearch", + "googlevertexai", + "jinaai", + "test_reranking_service", + "voyageai" + ).toArray(), providers ); } From 1ebe1a46f32549bcb7aaf8bd52a39b37cb87e8b1 Mon Sep 17 00:00:00 2001 From: fzowl Date: Wed, 19 Feb 2025 23:02:22 +0100 Subject: [PATCH 20/20] Correcting QA tests --- .../xpack/inference/InferenceGetServicesIT.java | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index 29d2b7e375788..859a065b6e1a0 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -105,15 +105,8 @@ public void testGetServicesWithRerankTaskType() throws IOException { } assertArrayEquals( - List.of( - "alibabacloud-ai-search", - "cohere", - "elasticsearch", - "googlevertexai", - "jinaai", - "test_reranking_service", - "voyageai" - ).toArray(), + List.of("alibabacloud-ai-search", "cohere", "elasticsearch", "googlevertexai", "jinaai", "test_reranking_service", "voyageai") + .toArray(), providers ); }