Skip to content

Commit 539645d

Browse files
Jan-Kazlouski-elastic and elasticsearchmachine authored
Add NVIDIA support to Inference Plugin (#132388)
Creation of a new NVIDIA inference provider integration supporting the following task types: text_embedding, completion, chat_completion, rerank.

Additional changes:
* Refactor to replace CohereTruncation with Truncation across multiple classes for improved consistency and clarity

---------

Co-authored-by: elasticsearchmachine <[email protected]>
1 parent 6f54921 commit 539645d

File tree

67 files changed

+7267 -118 lines changed (7267 additions, 118 deletions)

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

67 files changed

+7267
-118
lines changed

docs/changelog/132388.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 132388
2+
summary: Added NVIDIA support to Inference Plugin
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
{
2+
"inference.put_nvidia": {
3+
"documentation": {
4+
"url": "https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put-nvidia",
5+
"description": "Create an Nvidia inference endpoint"
6+
},
7+
"stability": "stable",
8+
"visibility": "public",
9+
"headers": {
10+
"accept": ["application/json"],
11+
"content_type": ["application/json"]
12+
},
13+
"url": {
14+
"paths": [
15+
{
16+
"path": "/_inference/{task_type}/{nvidia_inference_id}",
17+
"methods": ["PUT"],
18+
"parts": {
19+
"task_type": {
20+
"type": "enum",
21+
"description": "The task type",
22+
"options": [
23+
"rerank",
24+
"text_embedding",
25+
"completion",
26+
"chat_completion"
27+
]
28+
},
29+
"nvidia_inference_id": {
30+
"type": "string",
31+
"description": "The inference ID"
32+
}
33+
}
34+
}
35+
]
36+
},
37+
"body": {
38+
"description": "The inference endpoint's task and service settings",
39+
"required": true
40+
},
41+
"params": {
42+
"timeout": {
43+
"type": "time",
44+
"description": "Specifies the amount of time to wait for the inference endpoint to be created.",
45+
"default": "30s"
46+
}
47+
}
48+
}
49+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
9233000
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
downsample_add_multi_field_sources,9232000
1+
ml_inference_nvidia_added,9233000

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ public void testGetServicesWithoutTaskType() throws IOException {
7070
"jinaai",
7171
"llama",
7272
"mistral",
73+
"nvidia",
7374
"openai",
7475
"openshift_ai",
7576
"streaming_completion_test_service",
@@ -117,6 +118,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
117118
"jinaai",
118119
"llama",
119120
"mistral",
121+
"nvidia",
120122
"openai",
121123
"openshift_ai",
122124
"text_embedding_test_service",
@@ -143,6 +145,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
143145
"elasticsearch",
144146
"googlevertexai",
145147
"jinaai",
148+
"nvidia",
146149
"openshift_ai",
147150
"test_reranking_service",
148151
"voyageai",
@@ -177,6 +180,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
177180
"hugging_face",
178181
"amazon_sagemaker",
179182
"mistral",
183+
"nvidia",
180184
"watsonxai"
181185
).toArray()
182186
)
@@ -200,6 +204,7 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {
200204
"amazon_sagemaker",
201205
"googlevertexai",
202206
"mistral",
207+
"nvidia",
203208
"watsonxai"
204209
).toArray()
205210
)

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,10 @@
114114
import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsServiceSettings;
115115
import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionServiceSettings;
116116
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings;
117+
import org.elasticsearch.xpack.inference.services.nvidia.completion.NvidiaChatCompletionServiceSettings;
118+
import org.elasticsearch.xpack.inference.services.nvidia.embeddings.NvidiaEmbeddingsServiceSettings;
119+
import org.elasticsearch.xpack.inference.services.nvidia.embeddings.NvidiaEmbeddingsTaskSettings;
120+
import org.elasticsearch.xpack.inference.services.nvidia.rerank.NvidiaRerankServiceSettings;
117121
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionServiceSettings;
118122
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings;
119123
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings;
@@ -181,6 +185,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
181185
addLlamaNamedWriteables(namedWriteables);
182186
addAi21NamedWriteables(namedWriteables);
183187
addOpenShiftAiNamedWriteables(namedWriteables);
188+
addNvidiaNamedWriteables(namedWriteables);
184189

185190
addUnifiedNamedWriteables(namedWriteables);
186191

@@ -329,6 +334,29 @@ private static void addAi21NamedWriteables(List<NamedWriteableRegistry.Entry> na
329334
// no task settings for AI21
330335
}
331336

337+
private static void addNvidiaNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
338+
namedWriteables.add(
339+
new NamedWriteableRegistry.Entry(
340+
ServiceSettings.class,
341+
NvidiaChatCompletionServiceSettings.NAME,
342+
NvidiaChatCompletionServiceSettings::new
343+
)
344+
);
345+
namedWriteables.add(
346+
new NamedWriteableRegistry.Entry(
347+
ServiceSettings.class,
348+
NvidiaEmbeddingsServiceSettings.NAME,
349+
NvidiaEmbeddingsServiceSettings::new
350+
)
351+
);
352+
namedWriteables.add(
353+
new NamedWriteableRegistry.Entry(ServiceSettings.class, NvidiaRerankServiceSettings.NAME, NvidiaRerankServiceSettings::new)
354+
);
355+
namedWriteables.add(
356+
new NamedWriteableRegistry.Entry(TaskSettings.class, NvidiaEmbeddingsTaskSettings.NAME, NvidiaEmbeddingsTaskSettings::new)
357+
);
358+
}
359+
332360
private static void addAzureAiStudioNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
333361
namedWriteables.add(
334362
new NamedWriteableRegistry.Entry(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@
173173
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService;
174174
import org.elasticsearch.xpack.inference.services.llama.LlamaService;
175175
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
176+
import org.elasticsearch.xpack.inference.services.nvidia.NvidiaService;
176177
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
177178
import org.elasticsearch.xpack.inference.services.openshiftai.OpenShiftAiService;
178179
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerClient;
@@ -590,6 +591,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
590591
context -> new LlamaService(httpFactory.get(), serviceComponents.get(), context),
591592
context -> new Ai21Service(httpFactory.get(), serviceComponents.get(), context),
592593
context -> new OpenShiftAiService(httpFactory.get(), serviceComponents.get(), context),
594+
context -> new NvidiaService(httpFactory.get(), serviceComponents.get(), context),
593595
ElasticsearchInternalService::new,
594596
context -> new CustomService(httpFactory.get(), serviceComponents.get(), context)
595597
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncation.java renamed to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/model/Truncation.java

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,19 @@
55
* 2.0.
66
*/
77

8-
package org.elasticsearch.xpack.inference.services.cohere;
8+
package org.elasticsearch.xpack.inference.common.model;
99

1010
import java.util.EnumSet;
1111
import java.util.Locale;
1212

1313
/**
14-
* Defines the type of truncation for a cohere request. The specified value determines how the Cohere API will handle inputs
14+
* Defines the type of truncation for an embeddings request. The specified value determines how the provider's API will handle inputs
1515
* longer than the maximum token length.
16-
*
1716
* <p>
18-
* <a href="https://docs.cohere.com/reference/embed">See api docs for details.</a>
17+
* <a href="https://docs.cohere.com/reference/embed">Details can be found in Cohere embeddings API docs.</a>
1918
* </p>
2019
*/
21-
public enum CohereTruncation {
20+
public enum Truncation {
2221
/**
2322
* When the input exceeds the maximum input token length an error will be returned.
2423
*/
@@ -32,14 +31,14 @@ public enum CohereTruncation {
3231
*/
3332
END;
3433

35-
public static final EnumSet<CohereTruncation> ALL = EnumSet.allOf(CohereTruncation.class);
34+
public static final EnumSet<Truncation> ALL = EnumSet.allOf(Truncation.class);
3635

3736
@Override
3837
public String toString() {
3938
return name().toLowerCase(Locale.ROOT);
4039
}
4140

42-
public static CohereTruncation fromString(String name) {
41+
public static Truncation fromString(String name) {
4342
return valueOf(name.trim().toUpperCase(Locale.ROOT));
4443
}
4544
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ private static void checkProviderForTask(TaskType taskType, AmazonBedrockProvide
370370
}
371371

372372
private static void checkTaskSettingsForTextEmbeddingModel(AmazonBedrockEmbeddingsModel model) {
373-
if (model.provider() != AmazonBedrockProvider.COHERE && model.getTaskSettings().cohereTruncation() != null) {
373+
if (model.provider() != AmazonBedrockProvider.COHERE && model.getTaskSettings().truncation() != null) {
374374
throw new ElasticsearchStatusException(
375375
"The [{}] task type for provider [{}] does not allow [truncate] field",
376376
RestStatus.BAD_REQUEST,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsTaskSettings.java

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import org.elasticsearch.inference.ModelConfigurations;
1616
import org.elasticsearch.inference.TaskSettings;
1717
import org.elasticsearch.xcontent.XContentBuilder;
18-
import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
18+
import org.elasticsearch.xpack.inference.common.model.Truncation;
1919

2020
import java.io.IOException;
2121
import java.util.HashMap;
@@ -24,8 +24,8 @@
2424
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
2525
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TRUNCATE_FIELD;
2626

27-
public record AmazonBedrockEmbeddingsTaskSettings(@Nullable CohereTruncation cohereTruncation) implements TaskSettings {
28-
public static final AmazonBedrockEmbeddingsTaskSettings EMPTY = new AmazonBedrockEmbeddingsTaskSettings((CohereTruncation) null);
27+
public record AmazonBedrockEmbeddingsTaskSettings(@Nullable Truncation truncation) implements TaskSettings {
28+
public static final AmazonBedrockEmbeddingsTaskSettings EMPTY = new AmazonBedrockEmbeddingsTaskSettings((Truncation) null);
2929
public static final String NAME = "amazon_bedrock_embeddings_task_settings";
3030
private static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS = TransportVersion.fromName("amazon_bedrock_task_settings");
3131

@@ -36,36 +36,36 @@ public static AmazonBedrockEmbeddingsTaskSettings fromMap(Map<String, Object> ma
3636

3737
ValidationException validationException = new ValidationException();
3838

39-
var cohereTruncation = extractOptionalEnum(
39+
var extractedTruncation = extractOptionalEnum(
4040
map,
4141
TRUNCATE_FIELD,
4242
ModelConfigurations.TASK_SETTINGS,
43-
CohereTruncation::fromString,
44-
CohereTruncation.ALL,
43+
Truncation::fromString,
44+
Truncation.ALL,
4545
validationException
4646
);
4747

4848
if (validationException.validationErrors().isEmpty() == false) {
4949
throw validationException;
5050
}
5151

52-
return new AmazonBedrockEmbeddingsTaskSettings(cohereTruncation);
52+
return new AmazonBedrockEmbeddingsTaskSettings(extractedTruncation);
5353
}
5454

5555
public AmazonBedrockEmbeddingsTaskSettings(StreamInput in) throws IOException {
56-
this(in.readOptionalEnum(CohereTruncation.class));
56+
this(in.readOptionalEnum(Truncation.class));
5757
}
5858

5959
@Override
6060
public boolean isEmpty() {
61-
return cohereTruncation() == null;
61+
return truncation() == null;
6262
}
6363

6464
@Override
6565
public AmazonBedrockEmbeddingsTaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
6666
var newTaskSettings = fromMap(new HashMap<>(newSettings));
6767

68-
return new AmazonBedrockEmbeddingsTaskSettings(firstNonNullOrNull(newTaskSettings.cohereTruncation(), cohereTruncation()));
68+
return new AmazonBedrockEmbeddingsTaskSettings(firstNonNullOrNull(newTaskSettings.truncation(), truncation()));
6969
}
7070

7171
private static <T> T firstNonNullOrNull(T first, T second) {
@@ -90,14 +90,14 @@ public boolean supportsVersion(TransportVersion version) {
9090

9191
@Override
9292
public void writeTo(StreamOutput out) throws IOException {
93-
out.writeOptionalEnum(cohereTruncation());
93+
out.writeOptionalEnum(truncation());
9494
}
9595

9696
@Override
9797
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
9898
builder.startObject();
99-
if (cohereTruncation != null) {
100-
builder.field(TRUNCATE_FIELD, cohereTruncation);
99+
if (truncation != null) {
100+
builder.field(TRUNCATE_FIELD, truncation);
101101
}
102102
return builder.endObject();
103103
}

0 commit comments

Comments
 (0)