Skip to content

Commit bcca709

Browse files
authored
Merge pull request #2 from voyage-ai/voyageai
feat: VoyageAI integration
2 parents db7fa3c + b9681af commit bcca709

File tree

51 files changed

+8001
-0
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+8001
-0
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ static TransportVersion def(int id) {
186186
public static final TransportVersion ESQL_RETRY_ON_SHARD_LEVEL_FAILURE = def(9_006_0_00);
187187
public static final TransportVersion ESQL_PROFILE_ASYNC_NANOS = def(9_007_00_0);
188188

189+
public static final TransportVersion VOYAGE_AI_INTEGRATION_ADDED = def(9_005_0_00);
189190
/*
190191
* STOP! READ THIS FIRST! No, really,
191192
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@
127127
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService;
128128
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
129129
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
130+
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIService;
130131
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
131132

132133
import java.util.ArrayList;
@@ -357,6 +358,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
357358
context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get()),
358359
context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()),
359360
context -> new JinaAIService(httpFactory.get(), serviceComponents.get()),
361+
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()),
360362
ElasticsearchInternalService::new
361363
);
362364
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.action.voyageai;
9+
10+
import org.elasticsearch.inference.InputType;
11+
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
12+
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
13+
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
14+
import org.elasticsearch.xpack.inference.external.http.sender.VoyageAIEmbeddingsRequestManager;
15+
import org.elasticsearch.xpack.inference.external.http.sender.VoyageAIRerankRequestManager;
16+
import org.elasticsearch.xpack.inference.services.ServiceComponents;
17+
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel;
18+
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel;
19+
20+
import java.util.Map;
21+
import java.util.Objects;
22+
23+
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
24+
25+
/**
26+
* Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the voyageai model type.
27+
*/
28+
public class VoyageAIActionCreator implements VoyageAIActionVisitor {
29+
private final Sender sender;
30+
private final ServiceComponents serviceComponents;
31+
32+
public VoyageAIActionCreator(Sender sender, ServiceComponents serviceComponents) {
33+
this.sender = Objects.requireNonNull(sender);
34+
this.serviceComponents = Objects.requireNonNull(serviceComponents);
35+
}
36+
37+
@Override
38+
public ExecutableAction create(VoyageAIEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType) {
39+
var overriddenModel = VoyageAIEmbeddingsModel.of(model, taskSettings, inputType);
40+
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(
41+
overriddenModel.getServiceSettings().getCommonSettings().uri(),
42+
"VoyageAI embeddings"
43+
);
44+
var requestCreator = VoyageAIEmbeddingsRequestManager.of(overriddenModel, serviceComponents.threadPool());
45+
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
46+
}
47+
48+
@Override
49+
public ExecutableAction create(VoyageAIRerankModel model, Map<String, Object> taskSettings) {
50+
var overriddenModel = VoyageAIRerankModel.of(model, taskSettings);
51+
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(
52+
overriddenModel.getServiceSettings().getCommonSettings().uri(),
53+
"VoyageAI rerank"
54+
);
55+
var requestCreator = VoyageAIRerankRequestManager.of(overriddenModel, serviceComponents.threadPool());
56+
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
57+
}
58+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.action.voyageai;
9+
10+
import org.elasticsearch.inference.InputType;
11+
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
12+
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel;
13+
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel;
14+
15+
import java.util.Map;
16+
17+
public interface VoyageAIActionVisitor {
18+
ExecutableAction create(VoyageAIEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType);
19+
20+
ExecutableAction create(VoyageAIRerankModel model, Map<String, Object> taskSettings);
21+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.http.sender;
9+
10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.action.ActionListener;
13+
import org.elasticsearch.inference.InferenceServiceResults;
14+
import org.elasticsearch.threadpool.ThreadPool;
15+
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
16+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
17+
import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIEmbeddingsRequest;
18+
import org.elasticsearch.xpack.inference.external.response.voyageai.VoyageAIEmbeddingsResponseEntity;
19+
import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIResponseHandler;
20+
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel;
21+
22+
import java.util.List;
23+
import java.util.Objects;
24+
import java.util.function.Supplier;
25+
26+
public class VoyageAIEmbeddingsRequestManager extends VoyageAIRequestManager {
27+
private static final Logger logger = LogManager.getLogger(VoyageAIEmbeddingsRequestManager.class);
28+
private static final ResponseHandler HANDLER = createEmbeddingsHandler();
29+
30+
private static ResponseHandler createEmbeddingsHandler() {
31+
return new VoyageAIResponseHandler("voyageai text embedding", VoyageAIEmbeddingsResponseEntity::fromResponse);
32+
}
33+
34+
public static VoyageAIEmbeddingsRequestManager of(VoyageAIEmbeddingsModel model, ThreadPool threadPool) {
35+
return new VoyageAIEmbeddingsRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool));
36+
}
37+
38+
private final VoyageAIEmbeddingsModel model;
39+
40+
private VoyageAIEmbeddingsRequestManager(VoyageAIEmbeddingsModel model, ThreadPool threadPool) {
41+
super(threadPool, model);
42+
this.model = Objects.requireNonNull(model);
43+
}
44+
45+
@Override
46+
public void execute(
47+
InferenceInputs inferenceInputs,
48+
RequestSender requestSender,
49+
Supplier<Boolean> hasRequestCompletedFunction,
50+
ActionListener<InferenceServiceResults> listener
51+
) {
52+
List<String> docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs();
53+
VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest(docsInput, model);
54+
55+
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
56+
}
57+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.http.sender;
9+
10+
import org.elasticsearch.threadpool.ThreadPool;
11+
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIModel;
12+
13+
import java.util.Objects;
14+
15+
abstract class VoyageAIRequestManager extends BaseRequestManager {
16+
17+
protected VoyageAIRequestManager(ThreadPool threadPool, VoyageAIModel model) {
18+
super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
19+
}
20+
21+
record RateLimitGrouping(int apiKeyHash) {
22+
public static RateLimitGrouping of(VoyageAIModel model) {
23+
Objects.requireNonNull(model);
24+
25+
return new RateLimitGrouping(model.apiKey().hashCode());
26+
}
27+
}
28+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.http.sender;
9+
10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.action.ActionListener;
13+
import org.elasticsearch.inference.InferenceServiceResults;
14+
import org.elasticsearch.threadpool.ThreadPool;
15+
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
16+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
17+
import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIRerankRequest;
18+
import org.elasticsearch.xpack.inference.external.response.voyageai.VoyageAIRerankResponseEntity;
19+
import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIResponseHandler;
20+
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel;
21+
22+
import java.util.Objects;
23+
import java.util.function.Supplier;
24+
25+
public class VoyageAIRerankRequestManager extends VoyageAIRequestManager {
26+
private static final Logger logger = LogManager.getLogger(VoyageAIRerankRequestManager.class);
27+
private static final ResponseHandler HANDLER = createVoyageAIResponseHandler();
28+
29+
private static ResponseHandler createVoyageAIResponseHandler() {
30+
return new VoyageAIResponseHandler("voyageai rerank", (request, response) -> VoyageAIRerankResponseEntity.fromResponse(response));
31+
}
32+
33+
public static VoyageAIRerankRequestManager of(VoyageAIRerankModel model, ThreadPool threadPool) {
34+
return new VoyageAIRerankRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool));
35+
}
36+
37+
private final VoyageAIRerankModel model;
38+
39+
private VoyageAIRerankRequestManager(VoyageAIRerankModel model, ThreadPool threadPool) {
40+
super(threadPool, model);
41+
this.model = model;
42+
}
43+
44+
@Override
45+
public void execute(
46+
InferenceInputs inferenceInputs,
47+
RequestSender requestSender,
48+
Supplier<Boolean> hasRequestCompletedFunction,
49+
ActionListener<InferenceServiceResults> listener
50+
) {
51+
var rerankInput = QueryAndDocsInputs.of(inferenceInputs);
52+
VoyageAIRerankRequest request = new VoyageAIRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model);
53+
54+
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
55+
}
56+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.request.voyageai;
9+
10+
import org.apache.http.client.methods.HttpPost;
11+
import org.apache.http.client.utils.URIBuilder;
12+
import org.apache.http.entity.ByteArrayEntity;
13+
import org.elasticsearch.common.Strings;
14+
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
15+
import org.elasticsearch.xpack.inference.external.request.Request;
16+
import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIAccount;
17+
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel;
18+
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings;
19+
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings;
20+
21+
import java.net.URI;
22+
import java.net.URISyntaxException;
23+
import java.nio.charset.StandardCharsets;
24+
import java.util.List;
25+
import java.util.Objects;
26+
27+
public class VoyageAIEmbeddingsRequest extends VoyageAIRequest {
28+
29+
private final VoyageAIAccount account;
30+
private final List<String> input;
31+
private final VoyageAIEmbeddingsServiceSettings serviceSettings;
32+
private final VoyageAIEmbeddingsTaskSettings taskSettings;
33+
private final String model;
34+
private final String inferenceEntityId;
35+
36+
public VoyageAIEmbeddingsRequest(List<String> input, VoyageAIEmbeddingsModel embeddingsModel) {
37+
Objects.requireNonNull(embeddingsModel);
38+
39+
account = VoyageAIAccount.of(embeddingsModel, VoyageAIEmbeddingsRequest::buildDefaultUri);
40+
this.input = Objects.requireNonNull(input);
41+
serviceSettings = embeddingsModel.getServiceSettings();
42+
taskSettings = embeddingsModel.getTaskSettings();
43+
model = embeddingsModel.getServiceSettings().getCommonSettings().modelId();
44+
inferenceEntityId = embeddingsModel.getInferenceEntityId();
45+
}
46+
47+
@Override
48+
public HttpRequest createHttpRequest() {
49+
HttpPost httpPost = new HttpPost(account.uri());
50+
51+
ByteArrayEntity byteEntity = new ByteArrayEntity(
52+
Strings.toString(new VoyageAIEmbeddingsRequestEntity(input, serviceSettings, taskSettings, model))
53+
.getBytes(StandardCharsets.UTF_8)
54+
);
55+
httpPost.setEntity(byteEntity);
56+
57+
decorateWithAuthHeader(httpPost, account);
58+
59+
return new HttpRequest(httpPost, getInferenceEntityId());
60+
}
61+
62+
@Override
63+
public String getInferenceEntityId() {
64+
return inferenceEntityId;
65+
}
66+
67+
@Override
68+
public URI getURI() {
69+
return account.uri();
70+
}
71+
72+
@Override
73+
public Request truncate() {
74+
return this;
75+
}
76+
77+
@Override
78+
public boolean[] getTruncationInfo() {
79+
return null;
80+
}
81+
82+
public VoyageAIEmbeddingsTaskSettings getTaskSettings() {
83+
return taskSettings;
84+
}
85+
86+
public VoyageAIEmbeddingsServiceSettings getServiceSettings() {
87+
return serviceSettings;
88+
}
89+
90+
public static URI buildDefaultUri() throws URISyntaxException {
91+
return new URIBuilder().setScheme("https")
92+
.setHost(VoyageAIUtils.HOST)
93+
.setPathSegments(VoyageAIUtils.VERSION_1, VoyageAIUtils.EMBEDDINGS_PATH)
94+
.build();
95+
}
96+
}

0 commit comments

Comments
 (0)