Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
0535410
VoyageAI embeddings and rerank:
fzowl Feb 3, 2025
07c39a0
VoyageAI embeddings and rerank:
fzowl Feb 3, 2025
91dee7f
VoyageAI embeddings and rerank:
fzowl Feb 3, 2025
6f11414
VoyageAI embeddings and rerank:
fzowl Feb 3, 2025
050d5b2
Adding initial tests
fzowl Feb 4, 2025
b94bc5f
Correcting the TransportVersions.java
fzowl Feb 5, 2025
be1e9cf
Correcting due to comments
fzowl Feb 5, 2025
71dfdc8
Adding BIT support
fzowl Feb 5, 2025
8f6e03b
Initial tests
fzowl Feb 7, 2025
d41538a
More tests
fzowl Feb 8, 2025
3f1a75a
More tests/corrections
fzowl Feb 9, 2025
a89fdf1
Removing warnings
fzowl Feb 9, 2025
b7cb871
Further tests
fzowl Feb 9, 2025
b9681af
Transport version correction
fzowl Feb 9, 2025
bcca709
Merge pull request #2 from voyage-ai/voyageai
fzowl Feb 9, 2025
c211583
Adding changelog and correcting TransportVersions
fzowl Feb 9, 2025
a14ea32
Merge branch 'main' into main
fzowl Feb 10, 2025
91de4c3
Spotless tests
fzowl Feb 11, 2025
cb9fd17
Changes due to the comments
fzowl Feb 16, 2025
01c7632
Merge branch 'main' into main
fzowl Feb 16, 2025
850ab5d
Changes due to the comments
fzowl Feb 18, 2025
d6042e7
Correcting QA tests
fzowl Feb 19, 2025
1ebe1a4
Correcting QA tests
fzowl Feb 19, 2025
8796053
Merge branch 'main' of github.com:elastic/elasticsearch into voyage-a…
jonathan-buttner Feb 20, 2025
0a5b7c0
Merge branch 'main' into main
fzowl Feb 20, 2025
29fd189
Merge branch 'main' into main
fzowl Feb 20, 2025
baf9e3a
Merge branch 'main' into main
fzowl Feb 20, 2025
2a27c1d
Merge branch 'main' into main
jonathan-buttner Feb 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/122134.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 122134
summary: Adding integration for VoyageAI embeddings and rerank models
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ static TransportVersion def(int id) {
public static final TransportVersion REMOVE_ALL_APPLICABLE_SELECTOR_BACKPORT_8_19 = def(8_841_0_02);
public static final TransportVersion ESQL_RETRY_ON_SHARD_LEVEL_FAILURE_BACKPORT_8_19 = def(8_841_0_03);
public static final TransportVersion ESQL_SUPPORT_PARTIAL_RESULTS_BACKPORT_8_19 = def(8_841_0_04);
public static final TransportVersion VOYAGE_AI_INTEGRATION_ADDED_BACKPORT_8_X = def(8_841_0_05);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00);
public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01);
public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02);
Expand All @@ -199,7 +200,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_SUPPORT_PARTIAL_RESULTS = def(9_011_0_00);
public static final TransportVersion REMOVE_REPOSITORY_CONFLICT_MESSAGE = def(9_012_0_00);
public static final TransportVersion RERANKER_FAILURES_ALLOWED = def(9_013_0_00);

public static final TransportVersion VOYAGE_AI_INTEGRATION_ADDED = def(9_014_0_00);
/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
@SuppressWarnings("unchecked")
public void testGetServicesWithoutTaskType() throws IOException {
List<Object> services = getAllServices();
assertThat(services.size(), equalTo(19));
assertThat(services.size(), equalTo(20));

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Expand Down Expand Up @@ -53,6 +53,7 @@ public void testGetServicesWithoutTaskType() throws IOException {
"test_reranking_service",
"test_service",
"text_embedding_test_service",
"voyageai",
"watsonxai"
).toArray(),
providers
Expand All @@ -62,7 +63,7 @@ public void testGetServicesWithoutTaskType() throws IOException {
@SuppressWarnings("unchecked")
public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
assertThat(services.size(), equalTo(14));
assertThat(services.size(), equalTo(15));

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Expand All @@ -85,6 +86,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
"mistral",
"openai",
"text_embedding_test_service",
"voyageai",
"watsonxai"
).toArray(),
providers
Expand All @@ -94,7 +96,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
@SuppressWarnings("unchecked")
public void testGetServicesWithRerankTaskType() throws IOException {
List<Object> services = getServices(TaskType.RERANK);
assertThat(services.size(), equalTo(6));
assertThat(services.size(), equalTo(7));

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Expand All @@ -103,7 +105,8 @@ public void testGetServicesWithRerankTaskType() throws IOException {
}

assertArrayEquals(
List.of("alibabacloud-ai-search", "cohere", "elasticsearch", "googlevertexai", "jinaai", "test_reranking_service").toArray(),
List.of("alibabacloud-ai-search", "cohere", "elasticsearch", "googlevertexai", "jinaai", "test_reranking_service", "voyageai")
.toArray(),
providers
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankTaskSettings;

import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -142,6 +147,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
addEisNamedWriteables(namedWriteables);
addAlibabaCloudSearchNamedWriteables(namedWriteables);
addJinaAINamedWriteables(namedWriteables);
addVoyageAINamedWriteables(namedWriteables);

addUnifiedNamedWriteables(namedWriteables);

Expand Down Expand Up @@ -626,6 +632,28 @@ private static void addJinaAINamedWriteables(List<NamedWriteableRegistry.Entry>
);
}

private static void addVoyageAINamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(ServiceSettings.class, VoyageAIServiceSettings.NAME, VoyageAIServiceSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
VoyageAIEmbeddingsServiceSettings.NAME,
VoyageAIEmbeddingsServiceSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(TaskSettings.class, VoyageAIEmbeddingsTaskSettings.NAME, VoyageAIEmbeddingsTaskSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(ServiceSettings.class, VoyageAIRerankServiceSettings.NAME, VoyageAIRerankServiceSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(TaskSettings.class, VoyageAIRerankTaskSettings.NAME, VoyageAIRerankTaskSettings::new)
);
}

private static void addEisNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService;
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIService;
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;

import java.util.ArrayList;
Expand Down Expand Up @@ -359,6 +360,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get()),
context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()),
context -> new JinaAIService(httpFactory.get(), serviceComponents.get()),
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()),
ElasticsearchInternalService::new
);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.action.voyageai;

import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.external.http.sender.VoyageAIEmbeddingsRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.VoyageAIRerankRequestManager;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel;

import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;

/**
* Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the voyageai model type.
*/
public class VoyageAIActionCreator implements VoyageAIActionVisitor {
private final Sender sender;
private final ServiceComponents serviceComponents;

public VoyageAIActionCreator(Sender sender, ServiceComponents serviceComponents) {
this.sender = Objects.requireNonNull(sender);
this.serviceComponents = Objects.requireNonNull(serviceComponents);
}

@Override
public ExecutableAction create(VoyageAIEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType) {
var overriddenModel = VoyageAIEmbeddingsModel.of(model, taskSettings, inputType);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(overriddenModel.uri(), "VoyageAI embeddings");
var requestCreator = VoyageAIEmbeddingsRequestManager.of(overriddenModel, serviceComponents.threadPool());
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
}

@Override
public ExecutableAction create(VoyageAIRerankModel model, Map<String, Object> taskSettings) {
var overriddenModel = VoyageAIRerankModel.of(model, taskSettings);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(overriddenModel.uri(), "VoyageAI rerank");
var requestCreator = VoyageAIRerankRequestManager.of(overriddenModel, serviceComponents.threadPool());
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.action.voyageai;

import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel;

import java.util.Map;

public interface VoyageAIActionVisitor {
ExecutableAction create(VoyageAIEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType);

ExecutableAction create(VoyageAIRerankModel model, Map<String, Object> taskSettings);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIEmbeddingsRequest;
import org.elasticsearch.xpack.inference.external.response.voyageai.VoyageAIEmbeddingsResponseEntity;
import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIResponseHandler;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel;

import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;

public class VoyageAIEmbeddingsRequestManager extends VoyageAIRequestManager {
private static final Logger logger = LogManager.getLogger(VoyageAIEmbeddingsRequestManager.class);
private static final ResponseHandler HANDLER = createEmbeddingsHandler();

private static ResponseHandler createEmbeddingsHandler() {
return new VoyageAIResponseHandler("voyageai text embedding", VoyageAIEmbeddingsResponseEntity::fromResponse);
}

public static VoyageAIEmbeddingsRequestManager of(VoyageAIEmbeddingsModel model, ThreadPool threadPool) {
return new VoyageAIEmbeddingsRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool));
}

private final VoyageAIEmbeddingsModel model;

private VoyageAIEmbeddingsRequestManager(VoyageAIEmbeddingsModel model, ThreadPool threadPool) {
super(threadPool, model);
this.model = Objects.requireNonNull(model);
}

@Override
public void execute(
InferenceInputs inferenceInputs,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
List<String> docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs();
VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest(docsInput, model);

execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.http.sender;

import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIModel;

import java.util.Map;
import java.util.Objects;

abstract class VoyageAIRequestManager extends BaseRequestManager {
private static final String DEFAULT_MODEL_FAMILY = "default_model_family";
private static final Map<String, String> MODEL_TO_MODEL_FAMILY = Map.of(
"voyage-multimodal-3",
"embed_multimodal",
"voyage-3-large",
"embed_large",
"voyage-code-3",
"embed_large",
"voyage-3",
"embed_medium",
"voyage-3-lite",
"embed_small",
"voyage-finance-2",
"embed_large",
"voyage-law-2",
"embed_large",
"voyage-code-2",
"embed_large",
"rerank-2",
"rerank_large",
"rerank-2-lite",
"rerank_small"
);

protected VoyageAIRequestManager(ThreadPool threadPool, VoyageAIModel model) {
super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
}

record RateLimitGrouping(int apiKeyHash) {
public static RateLimitGrouping of(VoyageAIModel model) {
Objects.requireNonNull(model);
String modelId = model.getServiceSettings().modelId();
String modelFamily = MODEL_TO_MODEL_FAMILY.getOrDefault(modelId, DEFAULT_MODEL_FAMILY);

return new RateLimitGrouping(modelFamily.hashCode());
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIRerankRequest;
import org.elasticsearch.xpack.inference.external.response.voyageai.VoyageAIRerankResponseEntity;
import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIResponseHandler;
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel;

import java.util.Objects;
import java.util.function.Supplier;

public class VoyageAIRerankRequestManager extends VoyageAIRequestManager {
private static final Logger logger = LogManager.getLogger(VoyageAIRerankRequestManager.class);
private static final ResponseHandler HANDLER = createVoyageAIResponseHandler();

private static ResponseHandler createVoyageAIResponseHandler() {
return new VoyageAIResponseHandler("voyageai rerank", (request, response) -> VoyageAIRerankResponseEntity.fromResponse(response));
}

public static VoyageAIRerankRequestManager of(VoyageAIRerankModel model, ThreadPool threadPool) {
return new VoyageAIRerankRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool));
}

private final VoyageAIRerankModel model;

private VoyageAIRerankRequestManager(VoyageAIRerankModel model, ThreadPool threadPool) {
super(threadPool, model);
this.model = model;
}

@Override
public void execute(
InferenceInputs inferenceInputs,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
var rerankInput = QueryAndDocsInputs.of(inferenceInputs);
VoyageAIRerankRequest request = new VoyageAIRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model);

execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}
}
Loading