Skip to content

Commit 0da2fab

Browse files
davidkyle, JoanFM, jonathan-buttner
authored
[Inference API] Add Jina AI API to do inference for Embedding and Rerank models (#118652) (#119752)
# Conflicts: # x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java Co-authored-by: Joan Fontanals <[email protected]> Co-authored-by: Jonathan Buttner <[email protected]>
1 parent d18ce7d commit 0da2fab

File tree

49 files changed

+6784
-0
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+6784
-0
lines changed

docs/changelog/118652.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 118652
2+
summary: Add Jina AI API to do inference for Embedding and Rerank models
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,11 @@
7575
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
7676
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
7777
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
78+
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettings;
79+
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsServiceSettings;
80+
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings;
81+
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankServiceSettings;
82+
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettings;
7883
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings;
7984
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionServiceSettings;
8085
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings;
@@ -132,6 +137,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
132137
addAmazonBedrockNamedWriteables(namedWriteables);
133138
addEisNamedWriteables(namedWriteables);
134139
addAlibabaCloudSearchNamedWriteables(namedWriteables);
140+
addJinaAINamedWriteables(namedWriteables);
135141

136142
addUnifiedNamedWriteables(namedWriteables);
137143

@@ -569,6 +575,28 @@ private static void addAlibabaCloudSearchNamedWriteables(List<NamedWriteableRegi
569575

570576
}
571577

578+
private static void addJinaAINamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
579+
namedWriteables.add(
580+
new NamedWriteableRegistry.Entry(ServiceSettings.class, JinaAIServiceSettings.NAME, JinaAIServiceSettings::new)
581+
);
582+
namedWriteables.add(
583+
new NamedWriteableRegistry.Entry(
584+
ServiceSettings.class,
585+
JinaAIEmbeddingsServiceSettings.NAME,
586+
JinaAIEmbeddingsServiceSettings::new
587+
)
588+
);
589+
namedWriteables.add(
590+
new NamedWriteableRegistry.Entry(TaskSettings.class, JinaAIEmbeddingsTaskSettings.NAME, JinaAIEmbeddingsTaskSettings::new)
591+
);
592+
namedWriteables.add(
593+
new NamedWriteableRegistry.Entry(ServiceSettings.class, JinaAIRerankServiceSettings.NAME, JinaAIRerankServiceSettings::new)
594+
);
595+
namedWriteables.add(
596+
new NamedWriteableRegistry.Entry(TaskSettings.class, JinaAIRerankTaskSettings.NAME, JinaAIRerankTaskSettings::new)
597+
);
598+
}
599+
572600
private static void addEisNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
573601
namedWriteables.add(
574602
new NamedWriteableRegistry.Entry(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@
111111
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService;
112112
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserService;
113113
import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService;
114+
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService;
114115
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
115116
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
116117
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
@@ -289,6 +290,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
289290
context -> new AmazonBedrockService(httpFactory.get(), amazonBedrockFactory.get(), serviceComponents.get()),
290291
context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get()),
291292
context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()),
293+
context -> new JinaAIService(httpFactory.get(), serviceComponents.get()),
292294
ElasticsearchInternalService::new
293295
);
294296
}
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.action.jinaai;
9+
10+
import org.elasticsearch.inference.InputType;
11+
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
12+
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
13+
import org.elasticsearch.xpack.inference.external.http.sender.JinaAIEmbeddingsRequestManager;
14+
import org.elasticsearch.xpack.inference.external.http.sender.JinaAIRerankRequestManager;
15+
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
16+
import org.elasticsearch.xpack.inference.services.ServiceComponents;
17+
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModel;
18+
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankModel;
19+
20+
import java.util.Map;
21+
import java.util.Objects;
22+
23+
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
24+
25+
/**
26+
* Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the jinaai model type.
27+
*/
28+
public class JinaAIActionCreator implements JinaAIActionVisitor {
29+
private final Sender sender;
30+
private final ServiceComponents serviceComponents;
31+
32+
public JinaAIActionCreator(Sender sender, ServiceComponents serviceComponents) {
33+
this.sender = Objects.requireNonNull(sender);
34+
this.serviceComponents = Objects.requireNonNull(serviceComponents);
35+
}
36+
37+
@Override
38+
public ExecutableAction create(JinaAIEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType) {
39+
var overriddenModel = JinaAIEmbeddingsModel.of(model, taskSettings, inputType);
40+
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(
41+
overriddenModel.getServiceSettings().getCommonSettings().uri(),
42+
"JinaAI embeddings"
43+
);
44+
var requestCreator = JinaAIEmbeddingsRequestManager.of(overriddenModel, serviceComponents.threadPool());
45+
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
46+
}
47+
48+
@Override
49+
public ExecutableAction create(JinaAIRerankModel model, Map<String, Object> taskSettings) {
50+
var overriddenModel = JinaAIRerankModel.of(model, taskSettings);
51+
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(
52+
overriddenModel.getServiceSettings().getCommonSettings().uri(),
53+
"JinaAI rerank"
54+
);
55+
var requestCreator = JinaAIRerankRequestManager.of(overriddenModel, serviceComponents.threadPool());
56+
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
57+
}
58+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.action.jinaai;
9+
10+
import org.elasticsearch.inference.InputType;
11+
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
12+
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModel;
13+
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankModel;
14+
15+
import java.util.Map;
16+
17+
public interface JinaAIActionVisitor {
18+
ExecutableAction create(JinaAIEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType);
19+
20+
ExecutableAction create(JinaAIRerankModel model, Map<String, Object> taskSettings);
21+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpResult.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,4 +47,9 @@ private static byte[] limitBody(ByteSizeValue maxResponseSize, HttpResponse resp
4747
public boolean isBodyEmpty() {
4848
return body().length == 0;
4949
}
50+
51+
public boolean isSuccessfulResponse() {
52+
var code = response.getStatusLine().getStatusCode();
53+
return code >= 200 && code < 300;
54+
}
5055
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.http.sender;
9+
10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.action.ActionListener;
13+
import org.elasticsearch.inference.InferenceServiceResults;
14+
import org.elasticsearch.threadpool.ThreadPool;
15+
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
16+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
17+
import org.elasticsearch.xpack.inference.external.jinaai.JinaAIResponseHandler;
18+
import org.elasticsearch.xpack.inference.external.request.jinaai.JinaAIEmbeddingsRequest;
19+
import org.elasticsearch.xpack.inference.external.response.jinaai.JinaAIEmbeddingsResponseEntity;
20+
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModel;
21+
22+
import java.util.List;
23+
import java.util.Objects;
24+
import java.util.function.Supplier;
25+
26+
public class JinaAIEmbeddingsRequestManager extends JinaAIRequestManager {
27+
private static final Logger logger = LogManager.getLogger(JinaAIEmbeddingsRequestManager.class);
28+
private static final ResponseHandler HANDLER = createEmbeddingsHandler();
29+
30+
private static ResponseHandler createEmbeddingsHandler() {
31+
return new JinaAIResponseHandler("jinaai text embedding", JinaAIEmbeddingsResponseEntity::fromResponse);
32+
}
33+
34+
public static JinaAIEmbeddingsRequestManager of(JinaAIEmbeddingsModel model, ThreadPool threadPool) {
35+
return new JinaAIEmbeddingsRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool));
36+
}
37+
38+
private final JinaAIEmbeddingsModel model;
39+
40+
private JinaAIEmbeddingsRequestManager(JinaAIEmbeddingsModel model, ThreadPool threadPool) {
41+
super(threadPool, model);
42+
this.model = Objects.requireNonNull(model);
43+
}
44+
45+
@Override
46+
public void execute(
47+
InferenceInputs inferenceInputs,
48+
RequestSender requestSender,
49+
Supplier<Boolean> hasRequestCompletedFunction,
50+
ActionListener<InferenceServiceResults> listener
51+
) {
52+
List<String> docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs();
53+
JinaAIEmbeddingsRequest request = new JinaAIEmbeddingsRequest(docsInput, model);
54+
55+
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
56+
}
57+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.http.sender;
9+
10+
import org.elasticsearch.threadpool.ThreadPool;
11+
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIModel;
12+
13+
import java.util.Objects;
14+
15+
abstract class JinaAIRequestManager extends BaseRequestManager {
16+
17+
protected JinaAIRequestManager(ThreadPool threadPool, JinaAIModel model) {
18+
super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
19+
}
20+
21+
record RateLimitGrouping(int apiKeyHash) {
22+
public static RateLimitGrouping of(JinaAIModel model) {
23+
Objects.requireNonNull(model);
24+
25+
return new RateLimitGrouping(model.apiKey().hashCode());
26+
}
27+
}
28+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.http.sender;
9+
10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.action.ActionListener;
13+
import org.elasticsearch.inference.InferenceServiceResults;
14+
import org.elasticsearch.threadpool.ThreadPool;
15+
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
16+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
17+
import org.elasticsearch.xpack.inference.external.jinaai.JinaAIResponseHandler;
18+
import org.elasticsearch.xpack.inference.external.request.jinaai.JinaAIRerankRequest;
19+
import org.elasticsearch.xpack.inference.external.response.jinaai.JinaAIRerankResponseEntity;
20+
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankModel;
21+
22+
import java.util.Objects;
23+
import java.util.function.Supplier;
24+
25+
public class JinaAIRerankRequestManager extends JinaAIRequestManager {
26+
private static final Logger logger = LogManager.getLogger(JinaAIRerankRequestManager.class);
27+
private static final ResponseHandler HANDLER = createJinaAIResponseHandler();
28+
29+
private static ResponseHandler createJinaAIResponseHandler() {
30+
return new JinaAIResponseHandler("jinaai rerank", (request, response) -> JinaAIRerankResponseEntity.fromResponse(response));
31+
}
32+
33+
public static JinaAIRerankRequestManager of(JinaAIRerankModel model, ThreadPool threadPool) {
34+
return new JinaAIRerankRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool));
35+
}
36+
37+
private final JinaAIRerankModel model;
38+
39+
private JinaAIRerankRequestManager(JinaAIRerankModel model, ThreadPool threadPool) {
40+
super(threadPool, model);
41+
this.model = model;
42+
}
43+
44+
@Override
45+
public void execute(
46+
InferenceInputs inferenceInputs,
47+
RequestSender requestSender,
48+
Supplier<Boolean> hasRequestCompletedFunction,
49+
ActionListener<InferenceServiceResults> listener
50+
) {
51+
var rerankInput = QueryAndDocsInputs.of(inferenceInputs);
52+
JinaAIRerankRequest request = new JinaAIRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model);
53+
54+
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
55+
}
56+
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.jinaai;
9+
10+
import org.elasticsearch.common.CheckedSupplier;
11+
import org.elasticsearch.common.settings.SecureString;
12+
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIModel;
13+
14+
import java.net.URI;
15+
import java.net.URISyntaxException;
16+
import java.util.Objects;
17+
18+
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.buildUri;
19+
20+
public record JinaAIAccount(URI uri, SecureString apiKey) {
21+
22+
public static JinaAIAccount of(JinaAIModel model, CheckedSupplier<URI, URISyntaxException> uriBuilder) {
23+
var uri = buildUri(model.uri(), "JinaAI", uriBuilder);
24+
25+
return new JinaAIAccount(uri, model.apiKey());
26+
}
27+
28+
public JinaAIAccount {
29+
Objects.requireNonNull(uri);
30+
Objects.requireNonNull(apiKey);
31+
}
32+
}

0 commit comments

Comments
 (0)