Skip to content

Commit 7e639a2

Browse files
Add Nvidia integration for Completion and Chat Completion
1 parent ab7bd9b commit 7e639a2

File tree

13 files changed

+1069
-1
lines changed

13 files changed

+1069
-1
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,7 @@ static TransportVersion def(int id) {
353353
public static final TransportVersion NODE_WEIGHTS_ADDED_TO_NODE_BALANCE_STATS = def(9_129_0_00);
354354
public static final TransportVersion RERANK_SNIPPETS = def(9_130_0_00);
355355
public static final TransportVersion PIPELINE_TRACKING_INFO = def(9_131_0_00);
356+
public static final TransportVersion ML_INFERENCE_NVIDIA_ADDED = def(9_132_0_00);
356357

357358
/*
358359
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@
110110
import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsServiceSettings;
111111
import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionServiceSettings;
112112
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings;
113+
import org.elasticsearch.xpack.inference.services.nvidia.completion.NvidiaChatCompletionServiceSettings;
113114
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionServiceSettings;
114115
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings;
115116
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings;
@@ -178,6 +179,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
178179
addVoyageAINamedWriteables(namedWriteables);
179180
addCustomNamedWriteables(namedWriteables);
180181
addLlamaNamedWriteables(namedWriteables);
182+
addNvidiaNamedWriteables(namedWriteables);
181183

182184
addUnifiedNamedWriteables(namedWriteables);
183185

@@ -298,6 +300,17 @@ private static void addLlamaNamedWriteables(List<NamedWriteableRegistry.Entry> n
298300
// no task settings for Llama
299301
}
300302

303+
private static void addNvidiaNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
304+
namedWriteables.add(
305+
new NamedWriteableRegistry.Entry(
306+
ServiceSettings.class,
307+
NvidiaChatCompletionServiceSettings.NAME,
308+
NvidiaChatCompletionServiceSettings::new
309+
)
310+
);
311+
// no task settings for Nvidia
312+
}
313+
301314
private static void addAzureAiStudioNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
302315
namedWriteables.add(
303316
new NamedWriteableRegistry.Entry(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@
135135
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService;
136136
import org.elasticsearch.xpack.inference.services.llama.LlamaService;
137137
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
138+
import org.elasticsearch.xpack.inference.services.nvidia.NvidiaService;
138139
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
139140
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerClient;
140141
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerService;
@@ -413,6 +414,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
413414
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get(), context),
414415
context -> new DeepSeekService(httpFactory.get(), serviceComponents.get(), context),
415416
context -> new LlamaService(httpFactory.get(), serviceComponents.get(), context),
417+
context -> new NvidiaService(httpFactory.get(), serviceComponents.get(), context),
416418
ElasticsearchInternalService::new,
417419
context -> new CustomService(httpFactory.get(), serviceComponents.get(), context)
418420
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ public void setURI(String newUri) {
7676
/**
7777
* Retrieves the secret settings from the provided map of secrets.
7878
* If the map is null or empty, it returns an instance of EmptySecretSettings.
79-
* Caused by the fact that Llama model doesn't have out of the box security settings and can be used witout authentication.
79+
* Caused by the fact that Llama model doesn't have out of the box security settings and can be used without authentication.
8080
*
8181
* @param secrets the map containing secret settings
8282
* @return an instance of SecretSettings
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.nvidia;
9+
10+
import org.elasticsearch.inference.ModelConfigurations;
11+
import org.elasticsearch.inference.ModelSecrets;
12+
import org.elasticsearch.inference.ServiceSettings;
13+
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
14+
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
15+
import org.elasticsearch.xpack.inference.services.nvidia.action.NvidiaActionVisitor;
16+
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
17+
18+
/**
 * Abstract class representing an Nvidia model for inference.
 * This class extends RateLimitGroupingModel and provides common functionality for Nvidia models.
 */
public abstract class NvidiaModel extends RateLimitGroupingModel {

    /**
     * Constructor for creating a NvidiaModel with specified configurations and secrets.
     *
     * @param configurations the model configurations
     * @param secrets the secret settings for the model
     */
    protected NvidiaModel(ModelConfigurations configurations, ModelSecrets secrets) {
        super(configurations, secrets);
    }

    /**
     * Constructor for creating a NvidiaModel as a copy of an existing model with
     * overridden service settings.
     *
     * @param model the existing model to copy from
     * @param serviceSettings the settings for the inference service
     */
    protected NvidiaModel(RateLimitGroupingModel model, ServiceSettings serviceSettings) {
        super(model, serviceSettings);
    }

    /**
     * Returns the model's secret settings, narrowed to {@link DefaultSecretSettings}.
     * NOTE(review): unchecked downcast — this assumes every Nvidia model is constructed
     * with DefaultSecretSettings; any other SecretSettings implementation would throw a
     * ClassCastException here. Confirm against the concrete model constructors.
     */
    @Override
    public DefaultSecretSettings getSecretSettings() {
        return (DefaultSecretSettings) super.getSecretSettings();
    }

    /**
     * Accepts a visitor to create an executable action for the Nvidia model.
     *
     * @param creator the visitor that creates the executable action
     * @return an executable action for the Nvidia model
     */
    protected abstract ExecutableAction accept(NvidiaActionVisitor creator);
}

0 commit comments

Comments
 (0)