Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/132388.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 132388
summary: Added NVIDIA support to Inference Plugin
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
9189000
2 changes: 1 addition & 1 deletion server/src/main/resources/transport/upper_bounds/9.3.csv
Original file line number Diff line number Diff line change
@@ -1 +1 @@
inference_cached_tokens,9200000
ml_inference_nvidia_added,9189000
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@
import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.nvidia.completion.NvidiaChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.nvidia.embeddings.NvidiaEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.nvidia.rerank.NvidiaRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings;
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings;
Expand Down Expand Up @@ -170,6 +173,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
addCustomNamedWriteables(namedWriteables);
addLlamaNamedWriteables(namedWriteables);
addAi21NamedWriteables(namedWriteables);
addNvidiaNamedWriteables(namedWriteables);

addUnifiedNamedWriteables(namedWriteables);

Expand Down Expand Up @@ -305,6 +309,27 @@ private static void addAi21NamedWriteables(List<NamedWriteableRegistry.Entry> na
// no task settings for AI21
}

/**
 * Registers the wire-serialization entries for the Nvidia service settings.
 * One {@link ServiceSettings} entry is added per supported task type
 * (chat completion, embeddings, rerank); Nvidia defines no task settings.
 *
 * @param namedWriteables the registry entry list to append to
 */
private static void addNvidiaNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
    NamedWriteableRegistry.Entry chatCompletionEntry = new NamedWriteableRegistry.Entry(
        ServiceSettings.class,
        NvidiaChatCompletionServiceSettings.NAME,
        NvidiaChatCompletionServiceSettings::new
    );
    NamedWriteableRegistry.Entry embeddingsEntry = new NamedWriteableRegistry.Entry(
        ServiceSettings.class,
        NvidiaEmbeddingsServiceSettings.NAME,
        NvidiaEmbeddingsServiceSettings::new
    );
    NamedWriteableRegistry.Entry rerankEntry = new NamedWriteableRegistry.Entry(
        ServiceSettings.class,
        NvidiaRerankServiceSettings.NAME,
        NvidiaRerankServiceSettings::new
    );
    namedWriteables.add(chatCompletionEntry);
    namedWriteables.add(embeddingsEntry);
    namedWriteables.add(rerankEntry);
    // no task settings for Nvidia
}

private static void addAzureAiStudioNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService;
import org.elasticsearch.xpack.inference.services.llama.LlamaService;
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
import org.elasticsearch.xpack.inference.services.nvidia.NvidiaService;
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerClient;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerService;
Expand Down Expand Up @@ -426,6 +427,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
context -> new DeepSeekService(httpFactory.get(), serviceComponents.get(), context),
context -> new LlamaService(httpFactory.get(), serviceComponents.get(), context),
context -> new Ai21Service(httpFactory.get(), serviceComponents.get(), context),
context -> new NvidiaService(httpFactory.get(), serviceComponents.get(), context),
ElasticsearchInternalService::new,
context -> new CustomService(httpFactory.get(), serviceComponents.get(), context)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public void setURI(String newUri) {
/**
* Retrieves the secret settings from the provided map of secrets.
* If the map is null or empty, it returns an instance of EmptySecretSettings.
* Caused by the fact that Llama model doesn't have out of the box security settings and can be used witout authentication.
* Caused by the fact that Llama model doesn't have out of the box security settings and can be used without authentication.
*
* @param secrets the map containing secret settings
* @return an instance of SecretSettings
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.nvidia;

import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

import java.util.Objects;

/**
* Abstract class representing an Nvidia model for inference.
* This class extends RateLimitGroupingModel and provides common functionality for Nvidia models.
*/
/**
 * Abstract class representing an Nvidia model for inference.
 * This class extends RateLimitGroupingModel and provides common functionality for Nvidia models.
 * Requests are grouped for rate limiting by the combination of service endpoint URI and model id,
 * so callers hitting the same Nvidia endpoint/model share one rate-limit bucket.
 */
public abstract class NvidiaModel extends RateLimitGroupingModel {
    /**
     * Constructor for creating a NvidiaModel with specified configurations and secrets.
     *
     * @param configurations the model configurations
     * @param secrets the secret settings for the model
     */
    protected NvidiaModel(ModelConfigurations configurations, ModelSecrets secrets) {
        super(configurations, secrets);
    }

    /**
     * Constructor for creating a NvidiaModel from an existing model with overridden service settings.
     *
     * @param model the existing model whose configuration is reused
     * @param serviceSettings the settings for the inference service, replacing the original model's service settings
     */
    protected NvidiaModel(RateLimitGroupingModel model, ServiceSettings serviceSettings) {
        super(model, serviceSettings);
    }

    /**
     * @return the rate limit settings carried by this model's service settings
     */
    @Override
    public RateLimitSettings rateLimitSettings() {
        return getServiceSettings().rateLimitSettings();
    }

    /**
     * @return a grouping key derived from the endpoint URI and model id, so that
     *         models targeting the same endpoint and model share a rate-limit group
     */
    @Override
    public int rateLimitGroupingHash() {
        return Objects.hash(getServiceSettings().uri(), getServiceSettings().modelId());
    }

    /**
     * @return the service settings, narrowed to {@link NvidiaServiceSettings};
     *         subclasses must therefore be constructed with NvidiaServiceSettings
     */
    @Override
    public NvidiaServiceSettings getServiceSettings() {
        return (NvidiaServiceSettings) super.getServiceSettings();
    }

    /**
     * @return the secret settings, narrowed to {@link DefaultSecretSettings};
     *         subclasses must therefore be constructed with DefaultSecretSettings
     */
    @Override
    public DefaultSecretSettings getSecretSettings() {
        return (DefaultSecretSettings) super.getSecretSettings();
    }

}
Loading