Skip to content

Commit 521f855

Browse files
feat: VoyageAI integration (elastic#122134)
* VoyageAI embeddings and rerank: - embeddings works, tested - initial rerank code What's missing: - unit and integration tests - rerank request/response mapping and verification * VoyageAI embeddings and rerank: - embeddings works, tested - rerank works, tested (https://www.elastic.co/search-labs/blog/elasticsearch-cohere-rerank) What's missing: - unit and integration tests * VoyageAI embeddings and rerank: - embeddings works, tested - rerank works, tested (https://www.elastic.co/search-labs/blog/elasticsearch-cohere-rerank) What's missing: - unit and integration tests * VoyageAI embeddings and rerank: - embeddings works, tested - rerank works, tested (https://www.elastic.co/search-labs/blog/elasticsearch-cohere-rerank) What's missing: - unit and integration tests * Adding initial tests Moving dimensions to ServiceSettings * Correcting the TransportVersions.java * Correcting due to comments * Adding BIT support * Initial tests * More tests * More tests/corrections * Removing warnings * Further tests * Transport version correction * Adding changelog and correcting TransportVersions * Spotless tests * Changes due to the comments * Changes due to the comments * Correcting QA tests * Correcting QA tests --------- Co-authored-by: Jonathan Buttner <[email protected]> Co-authored-by: Jonathan Buttner <[email protected]>
1 parent 171a3b9 commit 521f855

File tree

54 files changed

+8140
-5
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+8140
-5
lines changed

docs/changelog/122134.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 122134
2+
summary: Adding integration for VoyageAI embeddings and rerank models
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ static TransportVersion def(int id) {
180180
public static final TransportVersion REMOVE_ALL_APPLICABLE_SELECTOR_BACKPORT_8_19 = def(8_841_0_02);
181181
public static final TransportVersion ESQL_RETRY_ON_SHARD_LEVEL_FAILURE_BACKPORT_8_19 = def(8_841_0_03);
182182
public static final TransportVersion ESQL_SUPPORT_PARTIAL_RESULTS_BACKPORT_8_19 = def(8_841_0_04);
183+
public static final TransportVersion VOYAGE_AI_INTEGRATION_ADDED_BACKPORT_8_X = def(8_841_0_05);
183184
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00);
184185
public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01);
185186
public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02);
@@ -199,7 +200,7 @@ static TransportVersion def(int id) {
199200
public static final TransportVersion ESQL_SUPPORT_PARTIAL_RESULTS = def(9_011_0_00);
200201
public static final TransportVersion REMOVE_REPOSITORY_CONFLICT_MESSAGE = def(9_012_0_00);
201202
public static final TransportVersion RERANKER_FAILURES_ALLOWED = def(9_013_0_00);
202-
203+
public static final TransportVersion VOYAGE_AI_INTEGRATION_ADDED = def(9_014_0_00);
203204
/*
204205
* STOP! READ THIS FIRST! No, really,
205206
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
2525
@SuppressWarnings("unchecked")
2626
public void testGetServicesWithoutTaskType() throws IOException {
2727
List<Object> services = getAllServices();
28-
assertThat(services.size(), equalTo(19));
28+
assertThat(services.size(), equalTo(20));
2929

3030
String[] providers = new String[services.size()];
3131
for (int i = 0; i < services.size(); i++) {
@@ -53,6 +53,7 @@ public void testGetServicesWithoutTaskType() throws IOException {
5353
"test_reranking_service",
5454
"test_service",
5555
"text_embedding_test_service",
56+
"voyageai",
5657
"watsonxai"
5758
).toArray(),
5859
providers
@@ -62,7 +63,7 @@ public void testGetServicesWithoutTaskType() throws IOException {
6263
@SuppressWarnings("unchecked")
6364
public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
6465
List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
65-
assertThat(services.size(), equalTo(14));
66+
assertThat(services.size(), equalTo(15));
6667

6768
String[] providers = new String[services.size()];
6869
for (int i = 0; i < services.size(); i++) {
@@ -85,6 +86,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
8586
"mistral",
8687
"openai",
8788
"text_embedding_test_service",
89+
"voyageai",
8890
"watsonxai"
8991
).toArray(),
9092
providers
@@ -94,7 +96,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
9496
@SuppressWarnings("unchecked")
9597
public void testGetServicesWithRerankTaskType() throws IOException {
9698
List<Object> services = getServices(TaskType.RERANK);
97-
assertThat(services.size(), equalTo(6));
99+
assertThat(services.size(), equalTo(7));
98100

99101
String[] providers = new String[services.size()];
100102
for (int i = 0; i < services.size(); i++) {
@@ -103,7 +105,8 @@ public void testGetServicesWithRerankTaskType() throws IOException {
103105
}
104106

105107
assertArrayEquals(
106-
List.of("alibabacloud-ai-search", "cohere", "elasticsearch", "googlevertexai", "jinaai", "test_reranking_service").toArray(),
108+
List.of("alibabacloud-ai-search", "cohere", "elasticsearch", "googlevertexai", "jinaai", "test_reranking_service", "voyageai")
109+
.toArray(),
107110
providers
108111
);
109112
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@
9090
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings;
9191
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettings;
9292
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
93+
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings;
94+
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings;
95+
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings;
96+
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankServiceSettings;
97+
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankTaskSettings;
9398

9499
import java.util.ArrayList;
95100
import java.util.List;
@@ -142,6 +147,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
142147
addEisNamedWriteables(namedWriteables);
143148
addAlibabaCloudSearchNamedWriteables(namedWriteables);
144149
addJinaAINamedWriteables(namedWriteables);
150+
addVoyageAINamedWriteables(namedWriteables);
145151

146152
addUnifiedNamedWriteables(namedWriteables);
147153

@@ -626,6 +632,28 @@ private static void addJinaAINamedWriteables(List<NamedWriteableRegistry.Entry>
626632
);
627633
}
628634

635+
private static void addVoyageAINamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
636+
namedWriteables.add(
637+
new NamedWriteableRegistry.Entry(ServiceSettings.class, VoyageAIServiceSettings.NAME, VoyageAIServiceSettings::new)
638+
);
639+
namedWriteables.add(
640+
new NamedWriteableRegistry.Entry(
641+
ServiceSettings.class,
642+
VoyageAIEmbeddingsServiceSettings.NAME,
643+
VoyageAIEmbeddingsServiceSettings::new
644+
)
645+
);
646+
namedWriteables.add(
647+
new NamedWriteableRegistry.Entry(TaskSettings.class, VoyageAIEmbeddingsTaskSettings.NAME, VoyageAIEmbeddingsTaskSettings::new)
648+
);
649+
namedWriteables.add(
650+
new NamedWriteableRegistry.Entry(ServiceSettings.class, VoyageAIRerankServiceSettings.NAME, VoyageAIRerankServiceSettings::new)
651+
);
652+
namedWriteables.add(
653+
new NamedWriteableRegistry.Entry(TaskSettings.class, VoyageAIRerankTaskSettings.NAME, VoyageAIRerankTaskSettings::new)
654+
);
655+
}
656+
629657
private static void addEisNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
630658
namedWriteables.add(
631659
new NamedWriteableRegistry.Entry(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@
128128
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService;
129129
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
130130
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
131+
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIService;
131132
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
132133

133134
import java.util.ArrayList;
@@ -359,6 +360,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
359360
context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get()),
360361
context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()),
361362
context -> new JinaAIService(httpFactory.get(), serviceComponents.get()),
363+
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()),
362364
ElasticsearchInternalService::new
363365
);
364366
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.action.voyageai;
9+
10+
import org.elasticsearch.inference.InputType;
11+
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
12+
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
13+
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
14+
import org.elasticsearch.xpack.inference.external.http.sender.VoyageAIEmbeddingsRequestManager;
15+
import org.elasticsearch.xpack.inference.external.http.sender.VoyageAIRerankRequestManager;
16+
import org.elasticsearch.xpack.inference.services.ServiceComponents;
17+
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel;
18+
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel;
19+
20+
import java.util.Map;
21+
import java.util.Objects;
22+
23+
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
24+
25+
/**
26+
* Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the voyageai model type.
27+
*/
28+
public class VoyageAIActionCreator implements VoyageAIActionVisitor {
29+
private final Sender sender;
30+
private final ServiceComponents serviceComponents;
31+
32+
public VoyageAIActionCreator(Sender sender, ServiceComponents serviceComponents) {
33+
this.sender = Objects.requireNonNull(sender);
34+
this.serviceComponents = Objects.requireNonNull(serviceComponents);
35+
}
36+
37+
@Override
38+
public ExecutableAction create(VoyageAIEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType) {
39+
var overriddenModel = VoyageAIEmbeddingsModel.of(model, taskSettings, inputType);
40+
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(overriddenModel.uri(), "VoyageAI embeddings");
41+
var requestCreator = VoyageAIEmbeddingsRequestManager.of(overriddenModel, serviceComponents.threadPool());
42+
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
43+
}
44+
45+
@Override
46+
public ExecutableAction create(VoyageAIRerankModel model, Map<String, Object> taskSettings) {
47+
var overriddenModel = VoyageAIRerankModel.of(model, taskSettings);
48+
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(overriddenModel.uri(), "VoyageAI rerank");
49+
var requestCreator = VoyageAIRerankRequestManager.of(overriddenModel, serviceComponents.threadPool());
50+
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
51+
}
52+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.action.voyageai;
9+
10+
import org.elasticsearch.inference.InputType;
11+
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
12+
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel;
13+
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel;
14+
15+
import java.util.Map;
16+
17+
public interface VoyageAIActionVisitor {
18+
ExecutableAction create(VoyageAIEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType);
19+
20+
ExecutableAction create(VoyageAIRerankModel model, Map<String, Object> taskSettings);
21+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.http.sender;
9+
10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.action.ActionListener;
13+
import org.elasticsearch.inference.InferenceServiceResults;
14+
import org.elasticsearch.threadpool.ThreadPool;
15+
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
16+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
17+
import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIEmbeddingsRequest;
18+
import org.elasticsearch.xpack.inference.external.response.voyageai.VoyageAIEmbeddingsResponseEntity;
19+
import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIResponseHandler;
20+
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel;
21+
22+
import java.util.List;
23+
import java.util.Objects;
24+
import java.util.function.Supplier;
25+
26+
public class VoyageAIEmbeddingsRequestManager extends VoyageAIRequestManager {
27+
private static final Logger logger = LogManager.getLogger(VoyageAIEmbeddingsRequestManager.class);
28+
private static final ResponseHandler HANDLER = createEmbeddingsHandler();
29+
30+
private static ResponseHandler createEmbeddingsHandler() {
31+
return new VoyageAIResponseHandler("voyageai text embedding", VoyageAIEmbeddingsResponseEntity::fromResponse);
32+
}
33+
34+
public static VoyageAIEmbeddingsRequestManager of(VoyageAIEmbeddingsModel model, ThreadPool threadPool) {
35+
return new VoyageAIEmbeddingsRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool));
36+
}
37+
38+
private final VoyageAIEmbeddingsModel model;
39+
40+
private VoyageAIEmbeddingsRequestManager(VoyageAIEmbeddingsModel model, ThreadPool threadPool) {
41+
super(threadPool, model);
42+
this.model = Objects.requireNonNull(model);
43+
}
44+
45+
@Override
46+
public void execute(
47+
InferenceInputs inferenceInputs,
48+
RequestSender requestSender,
49+
Supplier<Boolean> hasRequestCompletedFunction,
50+
ActionListener<InferenceServiceResults> listener
51+
) {
52+
List<String> docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs();
53+
VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest(docsInput, model);
54+
55+
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
56+
}
57+
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.http.sender;
9+
10+
import org.elasticsearch.threadpool.ThreadPool;
11+
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIModel;
12+
13+
import java.util.Map;
14+
import java.util.Objects;
15+
16+
abstract class VoyageAIRequestManager extends BaseRequestManager {
17+
private static final String DEFAULT_MODEL_FAMILY = "default_model_family";
18+
private static final Map<String, String> MODEL_TO_MODEL_FAMILY = Map.of(
19+
"voyage-multimodal-3",
20+
"embed_multimodal",
21+
"voyage-3-large",
22+
"embed_large",
23+
"voyage-code-3",
24+
"embed_large",
25+
"voyage-3",
26+
"embed_medium",
27+
"voyage-3-lite",
28+
"embed_small",
29+
"voyage-finance-2",
30+
"embed_large",
31+
"voyage-law-2",
32+
"embed_large",
33+
"voyage-code-2",
34+
"embed_large",
35+
"rerank-2",
36+
"rerank_large",
37+
"rerank-2-lite",
38+
"rerank_small"
39+
);
40+
41+
protected VoyageAIRequestManager(ThreadPool threadPool, VoyageAIModel model) {
42+
super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
43+
}
44+
45+
record RateLimitGrouping(int apiKeyHash) {
46+
public static RateLimitGrouping of(VoyageAIModel model) {
47+
Objects.requireNonNull(model);
48+
String modelId = model.getServiceSettings().modelId();
49+
String modelFamily = MODEL_TO_MODEL_FAMILY.getOrDefault(modelId, DEFAULT_MODEL_FAMILY);
50+
51+
return new RateLimitGrouping(modelFamily.hashCode());
52+
}
53+
}
54+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.http.sender;
9+
10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.action.ActionListener;
13+
import org.elasticsearch.inference.InferenceServiceResults;
14+
import org.elasticsearch.threadpool.ThreadPool;
15+
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
16+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
17+
import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIRerankRequest;
18+
import org.elasticsearch.xpack.inference.external.response.voyageai.VoyageAIRerankResponseEntity;
19+
import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIResponseHandler;
20+
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel;
21+
22+
import java.util.Objects;
23+
import java.util.function.Supplier;
24+
25+
public class VoyageAIRerankRequestManager extends VoyageAIRequestManager {
26+
private static final Logger logger = LogManager.getLogger(VoyageAIRerankRequestManager.class);
27+
private static final ResponseHandler HANDLER = createVoyageAIResponseHandler();
28+
29+
private static ResponseHandler createVoyageAIResponseHandler() {
30+
return new VoyageAIResponseHandler("voyageai rerank", (request, response) -> VoyageAIRerankResponseEntity.fromResponse(response));
31+
}
32+
33+
public static VoyageAIRerankRequestManager of(VoyageAIRerankModel model, ThreadPool threadPool) {
34+
return new VoyageAIRerankRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool));
35+
}
36+
37+
private final VoyageAIRerankModel model;
38+
39+
private VoyageAIRerankRequestManager(VoyageAIRerankModel model, ThreadPool threadPool) {
40+
super(threadPool, model);
41+
this.model = model;
42+
}
43+
44+
@Override
45+
public void execute(
46+
InferenceInputs inferenceInputs,
47+
RequestSender requestSender,
48+
Supplier<Boolean> hasRequestCompletedFunction,
49+
ActionListener<InferenceServiceResults> listener
50+
) {
51+
var rerankInput = QueryAndDocsInputs.of(inferenceInputs);
52+
VoyageAIRerankRequest request = new VoyageAIRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model);
53+
54+
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
55+
}
56+
}

0 commit comments

Comments
 (0)