Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/122134.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 122134
summary: Adding integration for VoyageAI embeddings and rerank models
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ static TransportVersion def(int id) {
public static final TransportVersion REMOVE_ALL_APPLICABLE_SELECTOR_BACKPORT_8_19 = def(8_841_0_02);
public static final TransportVersion ESQL_RETRY_ON_SHARD_LEVEL_FAILURE_BACKPORT_8_19 = def(8_841_0_03);
public static final TransportVersion ESQL_SUPPORT_PARTIAL_RESULTS_BACKPORT_8_19 = def(8_841_0_04);
public static final TransportVersion VOYAGE_AI_INTEGRATION_ADDED_BACKPORT_8_X = def(8_841_0_05);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
@SuppressWarnings("unchecked")
public void testGetServicesWithoutTaskType() throws IOException {
List<Object> services = getAllServices();
assertThat(services.size(), equalTo(19));
assertThat(services.size(), equalTo(20));

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Expand Down Expand Up @@ -54,6 +54,7 @@ public void testGetServicesWithoutTaskType() throws IOException {
"test_reranking_service",
"test_service",
"text_embedding_test_service",
"voyageai",
"watsonxai"
).toArray(),
providers
Expand All @@ -63,7 +64,7 @@ public void testGetServicesWithoutTaskType() throws IOException {
@SuppressWarnings("unchecked")
public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
assertThat(services.size(), equalTo(14));
assertThat(services.size(), equalTo(15));

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Expand All @@ -86,6 +87,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
"mistral",
"openai",
"text_embedding_test_service",
"voyageai",
"watsonxai"
).toArray(),
providers
Expand All @@ -95,7 +97,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
@SuppressWarnings("unchecked")
public void testGetServicesWithRerankTaskType() throws IOException {
List<Object> services = getServices(TaskType.RERANK);
assertThat(services.size(), equalTo(6));
assertThat(services.size(), equalTo(7));

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Expand All @@ -104,7 +106,8 @@ public void testGetServicesWithRerankTaskType() throws IOException {
}

assertArrayEquals(
List.of("alibabacloud-ai-search", "cohere", "elasticsearch", "googlevertexai", "jinaai", "test_reranking_service").toArray(),
List.of("alibabacloud-ai-search", "cohere", "elasticsearch", "googlevertexai", "jinaai", "test_reranking_service", "voyageai")
.toArray(),
providers
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankTaskSettings;

import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -143,6 +148,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
addEisNamedWriteables(namedWriteables);
addAlibabaCloudSearchNamedWriteables(namedWriteables);
addJinaAINamedWriteables(namedWriteables);
addVoyageAINamedWriteables(namedWriteables);

addUnifiedNamedWriteables(namedWriteables);

Expand Down Expand Up @@ -619,6 +625,28 @@ private static void addJinaAINamedWriteables(List<NamedWriteableRegistry.Entry>
);
}

private static void addVoyageAINamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(ServiceSettings.class, VoyageAIServiceSettings.NAME, VoyageAIServiceSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
VoyageAIEmbeddingsServiceSettings.NAME,
VoyageAIEmbeddingsServiceSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(TaskSettings.class, VoyageAIEmbeddingsTaskSettings.NAME, VoyageAIEmbeddingsTaskSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(ServiceSettings.class, VoyageAIRerankServiceSettings.NAME, VoyageAIRerankServiceSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(TaskSettings.class, VoyageAIRerankTaskSettings.NAME, VoyageAIRerankTaskSettings::new)
);
}

private static void addEisNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService;
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIService;
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;

import java.util.ArrayList;
Expand Down Expand Up @@ -359,6 +360,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get()),
context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()),
context -> new JinaAIService(httpFactory.get(), serviceComponents.get()),
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()),
ElasticsearchInternalService::new
);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.action.voyageai;

import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.external.http.sender.VoyageAIEmbeddingsRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.VoyageAIRerankRequestManager;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel;

import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;

/**
* Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the voyageai model type.
*/
public class VoyageAIActionCreator implements VoyageAIActionVisitor {
private final Sender sender;
private final ServiceComponents serviceComponents;

public VoyageAIActionCreator(Sender sender, ServiceComponents serviceComponents) {
this.sender = Objects.requireNonNull(sender);
this.serviceComponents = Objects.requireNonNull(serviceComponents);
}

@Override
public ExecutableAction create(VoyageAIEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType) {
var overriddenModel = VoyageAIEmbeddingsModel.of(model, taskSettings, inputType);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("VoyageAI embeddings");
var requestCreator = VoyageAIEmbeddingsRequestManager.of(overriddenModel, serviceComponents.threadPool());
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
}

@Override
public ExecutableAction create(VoyageAIRerankModel model, Map<String, Object> taskSettings) {
var overriddenModel = VoyageAIRerankModel.of(model, taskSettings);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("VoyageAI rerank");
var requestCreator = VoyageAIRerankRequestManager.of(overriddenModel, serviceComponents.threadPool());
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.action.voyageai;

import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel;

import java.util.Map;

public interface VoyageAIActionVisitor {
ExecutableAction create(VoyageAIEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType);

ExecutableAction create(VoyageAIRerankModel model, Map<String, Object> taskSettings);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIEmbeddingsRequest;
import org.elasticsearch.xpack.inference.external.response.voyageai.VoyageAIEmbeddingsResponseEntity;
import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIResponseHandler;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel;

import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;

public class VoyageAIEmbeddingsRequestManager extends VoyageAIRequestManager {
private static final Logger logger = LogManager.getLogger(VoyageAIEmbeddingsRequestManager.class);
private static final ResponseHandler HANDLER = createEmbeddingsHandler();

private static ResponseHandler createEmbeddingsHandler() {
return new VoyageAIResponseHandler("voyageai text embedding", VoyageAIEmbeddingsResponseEntity::fromResponse);
}

public static VoyageAIEmbeddingsRequestManager of(VoyageAIEmbeddingsModel model, ThreadPool threadPool) {
return new VoyageAIEmbeddingsRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool));
}

private final VoyageAIEmbeddingsModel model;

private VoyageAIEmbeddingsRequestManager(VoyageAIEmbeddingsModel model, ThreadPool threadPool) {
super(threadPool, model);
this.model = Objects.requireNonNull(model);
}

@Override
public void execute(
InferenceInputs inferenceInputs,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
List<String> docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs();
VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest(docsInput, model);

execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.http.sender;

import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIModel;

import java.util.Map;
import java.util.Objects;

abstract class VoyageAIRequestManager extends BaseRequestManager {
private static final String DEFAULT_MODEL_FAMILY = "default_model_family";
private static final Map<String, String> MODEL_TO_MODEL_FAMILY = Map.of(
"voyage-multimodal-3",
"embed_multimodal",
"voyage-3-large",
"embed_large",
"voyage-code-3",
"embed_large",
"voyage-3",
"embed_medium",
"voyage-3-lite",
"embed_small",
"voyage-finance-2",
"embed_large",
"voyage-law-2",
"embed_large",
"voyage-code-2",
"embed_large",
"rerank-2",
"rerank_large",
"rerank-2-lite",
"rerank_small"
);

protected VoyageAIRequestManager(ThreadPool threadPool, VoyageAIModel model) {
super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
}

record RateLimitGrouping(int apiKeyHash) {
public static RateLimitGrouping of(VoyageAIModel model) {
Objects.requireNonNull(model);
String modelId = model.getServiceSettings().modelId();
String modelFamily = MODEL_TO_MODEL_FAMILY.getOrDefault(modelId, DEFAULT_MODEL_FAMILY);

return new RateLimitGrouping(modelFamily.hashCode());
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIRerankRequest;
import org.elasticsearch.xpack.inference.external.response.voyageai.VoyageAIRerankResponseEntity;
import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIResponseHandler;
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel;

import java.util.Objects;
import java.util.function.Supplier;

public class VoyageAIRerankRequestManager extends VoyageAIRequestManager {
private static final Logger logger = LogManager.getLogger(VoyageAIRerankRequestManager.class);
private static final ResponseHandler HANDLER = createVoyageAIResponseHandler();

private static ResponseHandler createVoyageAIResponseHandler() {
return new VoyageAIResponseHandler("voyageai rerank", (request, response) -> VoyageAIRerankResponseEntity.fromResponse(response));
}

public static VoyageAIRerankRequestManager of(VoyageAIRerankModel model, ThreadPool threadPool) {
return new VoyageAIRerankRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool));
}

private final VoyageAIRerankModel model;

private VoyageAIRerankRequestManager(VoyageAIRerankModel model, ThreadPool threadPool) {
super(threadPool, model);
this.model = model;
}

@Override
public void execute(
InferenceInputs inferenceInputs,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
var rerankInput = QueryAndDocsInputs.of(inferenceInputs);
VoyageAIRerankRequest request = new VoyageAIRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model);

execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}
}
Loading