diff --git a/docs/changelog/132388.yaml b/docs/changelog/132388.yaml
new file mode 100644
index 0000000000000..98571ba91a4b1
--- /dev/null
+++ b/docs/changelog/132388.yaml
@@ -0,0 +1,5 @@
+pr: 132388
+summary: Added NVIDIA support to Inference Plugin
+area: Machine Learning
+type: enhancement
+issues: []
diff --git a/server/src/main/resources/transport/definitions/referable/ml_inference_nvidia_added.csv b/server/src/main/resources/transport/definitions/referable/ml_inference_nvidia_added.csv
new file mode 100644
index 0000000000000..95dace8502cbb
--- /dev/null
+++ b/server/src/main/resources/transport/definitions/referable/ml_inference_nvidia_added.csv
@@ -0,0 +1 @@
+9189000
diff --git a/server/src/main/resources/transport/upper_bounds/9.3.csv b/server/src/main/resources/transport/upper_bounds/9.3.csv
index 311c14ca764ac..758234b9f03f3 100644
--- a/server/src/main/resources/transport/upper_bounds/9.3.csv
+++ b/server/src/main/resources/transport/upper_bounds/9.3.csv
@@ -1 +1 @@
-inference_cached_tokens,9200000
+ml_inference_nvidia_added,9189000
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
index e3c7b829cbedd..3ea53dbce8aa6 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
@@ -109,6 +109,9 @@
 import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsServiceSettings;
 import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionServiceSettings;
 import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings;
+import org.elasticsearch.xpack.inference.services.nvidia.completion.NvidiaChatCompletionServiceSettings;
+import org.elasticsearch.xpack.inference.services.nvidia.embeddings.NvidiaEmbeddingsServiceSettings;
+import org.elasticsearch.xpack.inference.services.nvidia.rerank.NvidiaRerankServiceSettings;
 import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionServiceSettings;
 import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings;
 import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings;
@@ -170,6 +173,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
         addCustomNamedWriteables(namedWriteables);
         addLlamaNamedWriteables(namedWriteables);
         addAi21NamedWriteables(namedWriteables);
+        addNvidiaNamedWriteables(namedWriteables);
 
         addUnifiedNamedWriteables(namedWriteables);
 
@@ -305,6 +309,27 @@ private static void addAi21NamedWriteables(List<NamedWriteableRegistry.Entry> na
         // no task settings for AI21
     }
 
+    private static void addNvidiaNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(
+                ServiceSettings.class,
+                NvidiaChatCompletionServiceSettings.NAME,
+                NvidiaChatCompletionServiceSettings::new
+            )
+        );
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(
+                ServiceSettings.class,
+                NvidiaEmbeddingsServiceSettings.NAME,
+                NvidiaEmbeddingsServiceSettings::new
+            )
+        );
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(ServiceSettings.class, NvidiaRerankServiceSettings.NAME, NvidiaRerankServiceSettings::new)
+        );
+        // no task settings for Nvidia
+    }
+
     private static void addAzureAiStudioNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
         namedWriteables.add(
             new NamedWriteableRegistry.Entry(
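These registrations are what let the new settings classes cross the transport layer: a node writes the settings under its `NAME`, and the receiving node resolves the matching reader by that name. A minimal round-trip sketch in the style of the wire-serialization tests (the `settings` instance and surrounding test scaffolding are assumed, not part of this PR):

```java
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.inference.ServiceSettings;

// Build the registry exactly as a node would, from the provider patched above.
NamedWriteableRegistry registry = new NamedWriteableRegistry(InferenceNamedWriteablesProvider.getNamedWriteables());

try (BytesStreamOutput out = new BytesStreamOutput()) {
    out.writeNamedWriteable(settings); // writes the NAME, then the fields from writeTo()
    try (StreamInput in = new NamedWriteableAwareStreamInput(out.bytes().streamInput(), registry)) {
        ServiceSettings roundTripped = in.readNamedWriteable(ServiceSettings.class); // reader looked up by NAME
    }
}
```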
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
index 60592c5dd1dbd..38041b12d0779 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
@@ -143,6 +143,7 @@
 import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService;
 import org.elasticsearch.xpack.inference.services.llama.LlamaService;
 import org.elasticsearch.xpack.inference.services.mistral.MistralService;
+import org.elasticsearch.xpack.inference.services.nvidia.NvidiaService;
 import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
 import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerClient;
 import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerService;
@@ -426,6 +427,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
             context -> new DeepSeekService(httpFactory.get(), serviceComponents.get(), context),
             context -> new LlamaService(httpFactory.get(), serviceComponents.get(), context),
             context -> new Ai21Service(httpFactory.get(), serviceComponents.get(), context),
+            context -> new NvidiaService(httpFactory.get(), serviceComponents.get(), context),
             ElasticsearchInternalService::new,
             context -> new CustomService(httpFactory.get(), serviceComponents.get(), context)
         );
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaModel.java
index 3e24d058d8540..9b0b9d464fec2 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaModel.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaModel.java
@@ -76,7 +76,7 @@ public void setURI(String newUri) {
     /**
      * Retrieves the secret settings from the provided map of secrets.
      * If the map is null or empty, it returns an instance of EmptySecretSettings.
-     * Caused by the fact that Llama model doesn't have out of the box security settings and can be used witout authentication.
+     * Caused by the fact that Llama model doesn't have out of the box security settings and can be used without authentication.
      *
      * @param secrets the map containing secret settings
      * @return an instance of SecretSettings
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/NvidiaModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/NvidiaModel.java
new file mode 100644
index 0000000000000..c4b62cbcb66a6
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/NvidiaModel.java
@@ -0,0 +1,63 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.nvidia;
+
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.ServiceSettings;
+import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
+import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+
+import java.util.Objects;
+
+/**
+ * Abstract class representing an Nvidia model for inference.
+ * This class extends RateLimitGroupingModel and provides common functionality for Nvidia models.
+ */
+public abstract class NvidiaModel extends RateLimitGroupingModel {
+    /**
+     * Constructor for creating a NvidiaModel with specified configurations and secrets.
+     *
+     * @param configurations the model configurations
+     * @param secrets the secret settings for the model
+     */
+    protected NvidiaModel(ModelConfigurations configurations, ModelSecrets secrets) {
+        super(configurations, secrets);
+    }
+
+    /**
+     * Constructor for creating a NvidiaModel from an existing model with overridden service settings.
+     * @param model the base model to copy configurations and secrets from
+     * @param serviceSettings the overriding settings for the inference service
+     */
+    protected NvidiaModel(RateLimitGroupingModel model, ServiceSettings serviceSettings) {
+        super(model, serviceSettings);
+    }
+
+    @Override
+    public RateLimitSettings rateLimitSettings() {
+        return getServiceSettings().rateLimitSettings();
+    }
+
+    @Override
+    public int rateLimitGroupingHash() {
+        return Objects.hash(getServiceSettings().uri(), getServiceSettings().modelId());
+    }
+
+    @Override
+    public NvidiaServiceSettings getServiceSettings() {
+        return (NvidiaServiceSettings) super.getServiceSettings();
+    }
+
+    @Override
+    public DefaultSecretSettings getSecretSettings() {
+        return (DefaultSecretSettings) super.getSecretSettings();
+    }
+
+}
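`rateLimitGroupingHash()` above decides which endpoints share a request-rate bucket: the key is the (URI, model id) pair, not the inference entity id. A small illustration of the consequence (endpoint URL and model ids are hypothetical):

```java
import java.net.URI;
import java.util.Objects;

// Two inference endpoints configured against the same NVIDIA-hosted model produce
// the same grouping hash, so their traffic is throttled together; changing either
// the url or the model_id moves an endpoint into its own rate-limit bucket.
URI url = URI.create("https://integrate.api.nvidia.com/v1");
int endpointA = Objects.hash(url, "nvidia/nv-embedqa-e5-v5");
int endpointB = Objects.hash(url, "nvidia/nv-embedqa-e5-v5");
assert endpointA == endpointB; // shared bucket
```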
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/NvidiaService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/NvidiaService.java
new file mode 100644
index 0000000000000..ad2ed72d6e896
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/NvidiaService.java
@@ -0,0 +1,387 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.nvidia;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.util.LazyInitializable;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.ChunkInferenceInput;
+import org.elasticsearch.inference.ChunkedInference;
+import org.elasticsearch.inference.ChunkingSettings;
+import org.elasticsearch.inference.InferenceServiceConfiguration;
+import org.elasticsearch.inference.InferenceServiceExtension;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.RerankingInferenceService;
+import org.elasticsearch.inference.SettingsConfiguration;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
+import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder;
+import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
+import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
+import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
+import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
+import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
+import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.SenderService;
+import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.ServiceUtils;
+import org.elasticsearch.xpack.inference.services.nvidia.action.NvidiaActionCreator;
+import org.elasticsearch.xpack.inference.services.nvidia.completion.NvidiaChatCompletionModel;
+import org.elasticsearch.xpack.inference.services.nvidia.completion.NvidiaChatCompletionResponseHandler;
+import org.elasticsearch.xpack.inference.services.nvidia.embeddings.NvidiaEmbeddingsModel;
+import org.elasticsearch.xpack.inference.services.nvidia.request.completion.NvidiaChatCompletionRequest;
+import org.elasticsearch.xpack.inference.services.nvidia.rerank.NvidiaRerankModel;
+import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
+import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+
+import java.util.EnumSet;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import static org.elasticsearch.inference.TaskType.CHAT_COMPLETION;
+import static org.elasticsearch.inference.TaskType.COMPLETION;
+import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
+import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
+
+/**
+ * NvidiaService is an inference service for Nvidia models, supporting text embedding, completion,
+ * chat completion and rerank tasks.
+ * It extends SenderService to handle HTTP requests and responses for Nvidia models.
+ */
+public class NvidiaService extends SenderService implements RerankingInferenceService {
+    public static final String NAME = "nvidia";
+    private static final String SERVICE_NAME = "Nvidia";
+    /**
+     * The optimal batch size depends on the hardware the model is deployed on.
+     * For Nvidia, use a conservatively small max batch size, as it is
+     * unknown how the model is deployed.
+     */
+    static final int EMBEDDING_MAX_BATCH_SIZE = 20;
+    private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES = EnumSet.of(
+        TaskType.TEXT_EMBEDDING,
+        TaskType.COMPLETION,
+        TaskType.CHAT_COMPLETION,
+        TaskType.RERANK
+    );
+    private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new NvidiaChatCompletionResponseHandler(
+        "Nvidia chat completion",
+        OpenAiChatCompletionResponseEntity::fromResponse
+    );
+
+    /**
+     * Constructor for creating a NvidiaService with specified HTTP request sender factory and service components.
+     *
+     * @param factory the factory to create HTTP request senders
+     * @param serviceComponents the components required for the inference service
+     * @param context the context for the inference service factory
+     */
+    public NvidiaService(
+        HttpRequestSender.Factory factory,
+        ServiceComponents serviceComponents,
+        InferenceServiceExtension.InferenceServiceFactoryContext context
+    ) {
+        this(factory, serviceComponents, context.clusterService());
+    }
+
+    public NvidiaService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) {
+        super(factory, serviceComponents, clusterService);
+    }
+
+    @Override
+    protected void doInfer(
+        Model model,
+        InferenceInputs inputs,
+        Map<String, Object> taskSettings,
+        TimeValue timeout,
+        ActionListener<InferenceServiceResults> listener
+    ) {
+        var actionCreator = new NvidiaActionCreator(getSender(), getServiceComponents());
+        switch (model) {
+            case NvidiaChatCompletionModel nvidiaChatCompletionModel -> nvidiaChatCompletionModel.accept(actionCreator)
+                .execute(inputs, timeout, listener);
+            case NvidiaEmbeddingsModel nvidiaEmbeddingsModel -> nvidiaEmbeddingsModel.accept(actionCreator)
+                .execute(inputs, timeout, listener);
+            case NvidiaRerankModel nvidiaRerankModel -> nvidiaRerankModel.accept(actionCreator).execute(inputs, timeout, listener);
+            default -> listener.onFailure(createInvalidModelException(model));
+        }
+    }
+
+    @Override
+    protected void validateInputType(InputType inputType, Model model, ValidationException validationException) {
+        ServiceUtils.validateInputTypeIsUnspecifiedOrInternal(inputType, validationException);
+    }
+
+    /**
+     * Creates a NvidiaModel based on the provided parameters.
+     *
+     * @param inferenceId the unique identifier for the inference entity
+     * @param taskType the type of task this model is designed for
+     * @param serviceSettings the settings for the inference service
+     * @param chunkingSettings the settings for chunking, if applicable
+     * @param secretSettings the secret settings for the model, such as API keys or tokens
+     * @param context the context for parsing configuration settings
+     * @return a new instance of NvidiaModel based on the provided parameters
+     */
+    protected NvidiaModel createModel(
+        String inferenceId,
+        TaskType taskType,
+        Map<String, Object> serviceSettings,
+        ChunkingSettings chunkingSettings,
+        Map<String, Object> secretSettings,
+        ConfigurationParseContext context
+    ) {
+        return switch (taskType) {
+            case CHAT_COMPLETION, COMPLETION -> new NvidiaChatCompletionModel(
+                inferenceId,
+                taskType,
+                NAME,
+                serviceSettings,
+                secretSettings,
+                context
+            );
+            case TEXT_EMBEDDING -> new NvidiaEmbeddingsModel(
+                inferenceId,
+                taskType,
+                NAME,
+                serviceSettings,
+                chunkingSettings,
+                secretSettings,
+                context
+            );
+            case RERANK -> new NvidiaRerankModel(inferenceId, taskType, NAME, serviceSettings, secretSettings, context);
+            default -> throw createInvalidTaskTypeException(inferenceId, NAME, taskType, context);
+        };
+    }
+
+    @Override
+    protected void doUnifiedCompletionInfer(
+        Model model,
+        UnifiedChatInput inputs,
+        TimeValue timeout,
+        ActionListener<InferenceServiceResults> listener
+    ) {
+        if (model instanceof NvidiaChatCompletionModel == false) {
+            listener.onFailure(createInvalidModelException(model));
+            return;
+        }
+
+        var nvidiaChatCompletionModel = (NvidiaChatCompletionModel) model;
+        var overriddenModel = NvidiaChatCompletionModel.of(nvidiaChatCompletionModel, inputs.getRequest());
+        var manager = new GenericRequestManager<>(
+            getServiceComponents().threadPool(),
+            overriddenModel,
+            UNIFIED_CHAT_COMPLETION_HANDLER,
+            unifiedChatInput -> new NvidiaChatCompletionRequest(unifiedChatInput, overriddenModel),
+            UnifiedChatInput.class
+        );
+        var errorMessage = NvidiaActionCreator.buildErrorMessage(CHAT_COMPLETION, model.getInferenceEntityId());
+        var action = new SenderExecutableAction(getSender(), manager, errorMessage);
+
+        action.execute(inputs, timeout, listener);
+    }
+
+    @Override
+    protected void doChunkedInfer(
+        Model model,
+        List<ChunkInferenceInput> inputs,
+        Map<String, Object> taskSettings,
+        InputType inputType,
+        TimeValue timeout,
+        ActionListener<List<ChunkedInference>> listener
+    ) {
+        throw new UnsupportedOperationException("Nvidia service does not support chunked inference");
+    }
+
+    @Override
+    public Set<TaskType> supportedStreamingTasks() {
+        return EnumSet.of(COMPLETION, CHAT_COMPLETION);
+    }
+
+    @Override
+    public InferenceServiceConfiguration getConfiguration() {
+        return Configuration.get();
+    }
+
+    @Override
+    public EnumSet<TaskType> supportedTaskTypes() {
+        return SUPPORTED_TASK_TYPES;
+    }
+
+    @Override
+    public String name() {
+        return NAME;
+    }
+
+    @Override
+    public void parseRequestConfig(
+        String inferenceId,
+        TaskType taskType,
+        Map<String, Object> config,
+        ActionListener<Model> parsedModelListener
+    ) {
+        try {
+            Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
+            Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
+
+            ChunkingSettings chunkingSettings = null;
+            if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
+                chunkingSettings = ChunkingSettingsBuilder.fromMap(
+                    removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)
+                );
+            }
+
+            NvidiaModel model = createModel(
+                inferenceId,
+                taskType,
+                serviceSettingsMap,
+                chunkingSettings,
+                serviceSettingsMap, // on requests the api_key arrives inside service_settings, so the same map supplies the secrets
+                ConfigurationParseContext.REQUEST
+            );
+
+            throwIfNotEmptyMap(config, NAME);
+            throwIfNotEmptyMap(serviceSettingsMap, NAME);
+            throwIfNotEmptyMap(taskSettingsMap, NAME);
+
+            parsedModelListener.onResponse(model);
+        } catch (Exception e) {
+            parsedModelListener.onFailure(e);
+        }
+    }
+
+    @Override
+    public Model parsePersistedConfigWithSecrets(
+        String inferenceEntityId,
+        TaskType taskType,
+        Map<String, Object> config,
+        Map<String, Object> secrets
+    ) {
+        Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
+        removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
+        Map<String, Object> secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS);
+
+        ChunkingSettings chunkingSettings = null;
+        if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
+            chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
+        }
+
+        return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, chunkingSettings, secretSettingsMap);
+    }
+
+    private NvidiaModel createModelFromPersistent(
+        String inferenceEntityId,
+        TaskType taskType,
+        Map<String, Object> serviceSettings,
+        ChunkingSettings chunkingSettings,
+        Map<String, Object> secretSettings
+    ) {
+        return createModel(
+            inferenceEntityId,
+            taskType,
+            serviceSettings,
+            chunkingSettings,
+            secretSettings,
+            ConfigurationParseContext.PERSISTENT
+        );
+    }
+
+    @Override
+    public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
+        Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
+        removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
+
+        ChunkingSettings chunkingSettings = null;
+        if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
+            chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
+        }
+
+        return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, chunkingSettings, null);
+    }
+
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        return NvidiaUtils.ML_INFERENCE_NVIDIA_ADDED;
+    }
+
+    @Override
+    public boolean hideFromConfigurationApi() {
+        // The Nvidia service is very configurable so we're going to hide it from being exposed in the service API.
+        return true;
+    }
+
+    @Override
+    public int rerankerWindowSize(String modelId) {
+        // As Nvidia does not publish the max input length for their reranking models, we use a conservative default.
+        return CONSERVATIVE_DEFAULT_WINDOW_SIZE;
+    }
+
+    /**
+     * Configuration class for the Nvidia inference service.
+     * It provides the settings and configurations required for the service.
+     */
+    public static class Configuration {
+        public static InferenceServiceConfiguration get() {
+            return CONFIGURATION.getOrCompute();
+        }
+
+        private Configuration() {}
+
+        private static final LazyInitializable<InferenceServiceConfiguration, RuntimeException> CONFIGURATION = new LazyInitializable<>(
+            () -> {
+                var configurationMap = new HashMap<String, SettingsConfiguration>();
+
+                configurationMap.put(
+                    URL,
+                    new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDescription("The URL endpoint to use for the requests.")
+                        .setLabel("URL")
+                        .setRequired(true)
+                        .setSensitive(false)
+                        .setUpdatable(false)
+                        .setType(SettingsConfigurationFieldType.STRING)
+                        .build()
+                );
+                configurationMap.put(
+                    MODEL_ID,
+                    new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDescription(
+                        "Refer to the Nvidia models documentation for the list of available models."
+                    )
+                        .setLabel("Model")
+                        .setRequired(true)
+                        .setSensitive(false)
+                        .setUpdatable(false)
+                        .setType(SettingsConfigurationFieldType.STRING)
+                        .build()
+                );
+                configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES));
+                configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES));
+
+                return new InferenceServiceConfiguration.Builder().setService(NAME)
+                    .setName(SERVICE_NAME)
+                    .setTaskTypes(SUPPORTED_TASK_TYPES)
+                    .setConfigurations(configurationMap)
+                    .build();
+            }
+        );
+    }
+}
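One subtlety in `parseRequestConfig` above: `serviceSettingsMap` is passed both as the service settings and as the source of the secrets, because on a create request the `api_key` arrives inside `service_settings` and `DefaultSecretSettings.fromMap(...)` pulls it out of that same map. A hedged sketch of the map the method consumes, given a `service` instance and a `listener` (field values are illustrative):

```java
Map<String, Object> config = new HashMap<>();
config.put("service_settings", new HashMap<>(Map.of(
    "model_id", "nvidia/nv-embedqa-e5-v5",        // required
    "url", "https://integrate.api.nvidia.com/v1", // optional, illustrative
    "api_key", "<key>"                            // consumed by DefaultSecretSettings.fromMap(...)
)));
// TEXT_EMBEDDING endpoints additionally honor an optional "chunking_settings" object.
service.parseRequestConfig("my-nvidia-embeddings", TaskType.TEXT_EMBEDDING, config, listener);
```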
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/NvidiaServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/NvidiaServiceSettings.java
new file mode 100644
index 0000000000000..00f47c40f351a
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/NvidiaServiceSettings.java
@@ -0,0 +1,115 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.nvidia;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.ServiceSettings;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+
+import java.io.IOException;
+import java.net.URI;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
+import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createOptionalUri;
+
+/**
+ * Represents the settings for an Nvidia service.
+ * This class encapsulates the model ID, URI, and rate limit settings for the Nvidia service.
+ */
+public abstract class NvidiaServiceSettings extends FilteredXContentObject implements ServiceSettings {
+    // There is no default rate limit for Nvidia, so we set a reasonable default of 3000 requests per minute
+    protected static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000);
+
+    protected final String modelId;
+    protected final URI uri;
+    protected final RateLimitSettings rateLimitSettings;
+
+    /**
+     * Constructs a new NvidiaServiceSettings from a StreamInput.
+     *
+     * @param in the StreamInput to read from
+     * @throws IOException if an I/O error occurs during reading
+     */
+    protected NvidiaServiceSettings(StreamInput in) throws IOException {
+        this.modelId = in.readString();
+        this.uri = createOptionalUri(in.readOptionalString());
+        this.rateLimitSettings = new RateLimitSettings(in);
+    }
+
+    protected NvidiaServiceSettings(String modelId, @Nullable URI uri, @Nullable RateLimitSettings rateLimitSettings) {
+        this.modelId = modelId;
+        this.uri = uri;
+        this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
+    }
+
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        assert false : "should never be called when supportsVersion is used";
+        return NvidiaUtils.ML_INFERENCE_NVIDIA_ADDED;
+    }
+
+    @Override
+    public boolean supportsVersion(TransportVersion version) {
+        return NvidiaUtils.supportsNvidia(version);
+    }
+
+    @Override
+    public String modelId() {
+        return this.modelId;
+    }
+
+    /**
+     * Returns the URI of the Nvidia service.
+     *
+     * @return the URI of the service
+     */
+    public URI uri() {
+        return this.uri;
+    }
+
+    /**
+     * Returns the rate limit settings for the Nvidia service.
+     *
+     * @return the rate limit settings
+     */
+    public RateLimitSettings rateLimitSettings() {
+        return this.rateLimitSettings;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeString(modelId);
+        out.writeOptionalString(uri != null ? uri.toString() : null);
+        rateLimitSettings.writeTo(out);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        toXContentFragmentOfExposedFields(builder, params);
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
+        builder.field(MODEL_ID, modelId);
+        if (uri != null) {
+            builder.field(URL, uri.toString());
+        }
+        rateLimitSettings.toXContent(builder, params);
+        return builder;
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/NvidiaUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/NvidiaUtils.java
new file mode 100644
index 0000000000000..3eb0f29fed2f6
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/NvidiaUtils.java
@@ -0,0 +1,22 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.nvidia;
+
+import org.elasticsearch.TransportVersion;
+
+public final class NvidiaUtils {
+
+    public static final TransportVersion ML_INFERENCE_NVIDIA_ADDED = TransportVersion.fromName("ml_inference_nvidia_added");
+
+    public static boolean supportsNvidia(TransportVersion version) {
+        return version.supports(ML_INFERENCE_NVIDIA_ADDED);
+    }
+
+    private NvidiaUtils() {}
+
+}
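`ML_INFERENCE_NVIDIA_ADDED` resolves to the id registered in `ml_inference_nvidia_added.csv` (9189000), and `supportsNvidia` is the mixed-version-cluster guard built on it. A hypothetical use in serialization code, not taken from this PR:

```java
// Only emit Nvidia-specific state to nodes whose transport stream already knows
// about it; older nodes never receive bytes they cannot parse.
@Override
public void writeTo(StreamOutput out) throws IOException {
    if (NvidiaUtils.supportsNvidia(out.getTransportVersion())) {
        // serialize fields introduced with ml_inference_nvidia_added
    }
}
```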
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/action/NvidiaActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/action/NvidiaActionCreator.java
new file mode 100644
index 0000000000000..b3dc2cfbc601c
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/action/NvidiaActionCreator.java
@@ -0,0 +1,135 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.nvidia.action;
+
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
+import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
+import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
+import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
+import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
+import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
+import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
+import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
+import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.nvidia.completion.NvidiaChatCompletionModel;
+import org.elasticsearch.xpack.inference.services.nvidia.completion.NvidiaCompletionResponseHandler;
+import org.elasticsearch.xpack.inference.services.nvidia.embeddings.NvidiaEmbeddingsModel;
+import org.elasticsearch.xpack.inference.services.nvidia.embeddings.NvidiaEmbeddingsResponseHandler;
+import org.elasticsearch.xpack.inference.services.nvidia.request.completion.NvidiaChatCompletionRequest;
+import org.elasticsearch.xpack.inference.services.nvidia.request.embeddings.NvidiaEmbeddingsRequest;
+import org.elasticsearch.xpack.inference.services.nvidia.request.rerank.NvidiaRerankRequest;
+import org.elasticsearch.xpack.inference.services.nvidia.rerank.NvidiaRerankModel;
+import org.elasticsearch.xpack.inference.services.nvidia.rerank.NvidiaRerankResponseHandler;
+import org.elasticsearch.xpack.inference.services.nvidia.response.NvidiaRankedResponseEntity;
+import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
+import org.elasticsearch.xpack.inference.services.openai.response.OpenAiEmbeddingsResponseEntity;
+
+import java.util.Objects;
+
+import static org.elasticsearch.core.Strings.format;
+import static org.elasticsearch.xpack.inference.common.Truncator.truncate;
+
+/**
+ * Creates actions for Nvidia inference requests, handling embeddings, completions and rerank.
+ * This class implements the {@link NvidiaActionVisitor} interface to provide specific action creation methods.
+ */
+public class NvidiaActionCreator implements NvidiaActionVisitor {
+
+    private static final String FAILED_TO_SEND_REQUEST_ERROR_MESSAGE = "Failed to send Nvidia %s request from inference entity id [%s]";
+    private static final String COMPLETION_ERROR_PREFIX = "Nvidia completions";
+    private static final String USER_ROLE = "user";
+
+    private static final ResponseHandler EMBEDDINGS_HANDLER = new NvidiaEmbeddingsResponseHandler(
+        "Nvidia text embedding",
+        OpenAiEmbeddingsResponseEntity::fromResponse
+    );
+
+    private static final ResponseHandler COMPLETION_HANDLER = new NvidiaCompletionResponseHandler(
+        "Nvidia completion",
+        OpenAiChatCompletionResponseEntity::fromResponse
+    );
+
+    private static final ResponseHandler RERANK_HANDLER = new NvidiaRerankResponseHandler(
+        "Nvidia rerank",
+        (request, response) -> NvidiaRankedResponseEntity.fromResponse(response),
+        false
+    );
+
+    private final Sender sender;
+    private final ServiceComponents serviceComponents;
+
+    /**
+     * Constructs a new NvidiaActionCreator with the specified sender and service components.
+     *
+     * @param sender the sender to use for executing actions
+     * @param serviceComponents the service components providing necessary services
+     */
+    public NvidiaActionCreator(Sender sender, ServiceComponents serviceComponents) {
+        this.sender = Objects.requireNonNull(sender);
+        this.serviceComponents = Objects.requireNonNull(serviceComponents);
+    }
+
+    @Override
+    public ExecutableAction create(NvidiaEmbeddingsModel model) {
+        var manager = new GenericRequestManager<>(
+            serviceComponents.threadPool(),
+            model,
+            EMBEDDINGS_HANDLER,
+            embeddingsInput -> new NvidiaEmbeddingsRequest(
+                serviceComponents.truncator(),
+                truncate(embeddingsInput.getStringInputs(), model.getServiceSettings().maxInputTokens()),
+                model
+            ),
+            EmbeddingsInput.class
+        );
+
+        var errorMessage = buildErrorMessage(TaskType.TEXT_EMBEDDING, model.getInferenceEntityId());
+        return new SenderExecutableAction(sender, manager, errorMessage);
+    }
+
+    @Override
+    public ExecutableAction create(NvidiaChatCompletionModel model) {
+        var manager = new GenericRequestManager<>(
+            serviceComponents.threadPool(),
+            model,
+            COMPLETION_HANDLER,
+            inputs -> new NvidiaChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model),
+            ChatCompletionInput.class
+        );
+
+        var errorMessage = buildErrorMessage(TaskType.COMPLETION, model.getInferenceEntityId());
+        return new SingleInputSenderExecutableAction(sender, manager, errorMessage, COMPLETION_ERROR_PREFIX);
+    }
+
+    @Override
+    public ExecutableAction create(NvidiaRerankModel model) {
+        var manager = new GenericRequestManager<>(
+            serviceComponents.threadPool(),
+            model,
+            RERANK_HANDLER,
+            inputs -> new NvidiaRerankRequest(inputs.getQuery(), inputs.getChunks(), model),
+            QueryAndDocsInputs.class
+        );
+        var errorMessage = buildErrorMessage(TaskType.RERANK, model.getInferenceEntityId());
+        return new SenderExecutableAction(sender, manager, errorMessage);
+    }
+
+    /**
+     * Builds an error message for failed requests.
+     *
+     * @param requestType the type of request that failed
+     * @param inferenceId the inference entity ID associated with the request
+     * @return a formatted error message
+     */
+    public static String buildErrorMessage(TaskType requestType, String inferenceId) {
+        return format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, requestType.toString(), inferenceId);
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/action/NvidiaActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/action/NvidiaActionVisitor.java
new file mode 100644
index 0000000000000..617a50fd2f1ce
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/action/NvidiaActionVisitor.java
@@ -0,0 +1,42 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.nvidia.action;
+
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.services.nvidia.completion.NvidiaChatCompletionModel;
+import org.elasticsearch.xpack.inference.services.nvidia.embeddings.NvidiaEmbeddingsModel;
+import org.elasticsearch.xpack.inference.services.nvidia.rerank.NvidiaRerankModel;
+
+/**
+ * Visitor interface for creating executable actions for Nvidia inference services.
+ */
+public interface NvidiaActionVisitor {
+    /**
+     * Creates an executable action for the given Nvidia embeddings model.
+     *
+     * @param model the Nvidia embeddings model
+     * @return an executable action for the embeddings model
+     */
+    ExecutableAction create(NvidiaEmbeddingsModel model);
+
+    /**
+     * Creates an executable action for the given Nvidia chat completion model.
+     *
+     * @param model the Nvidia chat completion model
+     * @return an executable action for the chat completion model
+     */
+    ExecutableAction create(NvidiaChatCompletionModel model);
+
+    /**
+     * Creates an executable action for the given Nvidia rerank model.
+     *
+     * @param model The Nvidia rerank model.
+     * @return An executable action for the rerank model.
+     */
+    ExecutableAction create(NvidiaRerankModel model);
+}
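Together with the `accept(...)` methods on the models, this interface gives classic double dispatch: `doInfer` never switches on task type, only on the model's own class. Roughly, using the names from this PR:

```java
// Inside NvidiaService.doInfer: the model selects the visitor overload, the creator
// assembles request manager plus response handler, and the action runs the HTTP call.
var actionCreator = new NvidiaActionCreator(getSender(), getServiceComponents());
ExecutableAction action = nvidiaEmbeddingsModel.accept(actionCreator); // -> create(NvidiaEmbeddingsModel)
action.execute(inputs, timeout, listener);
```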
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/completion/NvidiaChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/completion/NvidiaChatCompletionModel.java
new file mode 100644
index 0000000000000..d6dae2fe9b8b0
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/completion/NvidiaChatCompletionModel.java
@@ -0,0 +1,132 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.nvidia.completion;
+
+import org.elasticsearch.inference.EmptyTaskSettings;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.SecretSettings;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.inference.UnifiedCompletionRequest;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.nvidia.NvidiaModel;
+import org.elasticsearch.xpack.inference.services.nvidia.action.NvidiaActionVisitor;
+import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+
+import java.util.Map;
+
+/**
+ * Represents a Nvidia chat completion model for inference.
+ * This class extends the NvidiaModel and provides specific configurations and settings for chat completion tasks.
+ */
+public class NvidiaChatCompletionModel extends NvidiaModel {
+
+    /**
+     * Constructor for creating a NvidiaChatCompletionModel with specified parameters.
+     * @param inferenceEntityId the unique identifier for the inference entity
+     * @param taskType the type of task this model is designed for
+     * @param service the name of the inference service
+     * @param serviceSettings the settings for the inference service, specific to chat completion
+     * @param secrets the secret settings for the model, such as API keys or tokens
+     * @param context the context for parsing configuration settings
+     */
+    public NvidiaChatCompletionModel(
+        String inferenceEntityId,
+        TaskType taskType,
+        String service,
+        Map<String, Object> serviceSettings,
+        Map<String, Object> secrets,
+        ConfigurationParseContext context
+    ) {
+        this(
+            inferenceEntityId,
+            taskType,
+            service,
+            NvidiaChatCompletionServiceSettings.fromMap(serviceSettings, context),
+            DefaultSecretSettings.fromMap(secrets)
+        );
+    }
+
+    /**
+     * Constructor for creating a NvidiaChatCompletionModel with specified parameters.
+     * @param inferenceEntityId the unique identifier for the inference entity
+     * @param taskType the type of task this model is designed for
+     * @param service the name of the inference service
+     * @param serviceSettings the settings for the inference service, specific to chat completion
+     * @param secrets the secret settings for the model, such as API keys or tokens
+     */
+    public NvidiaChatCompletionModel(
+        String inferenceEntityId,
+        TaskType taskType,
+        String service,
+        NvidiaChatCompletionServiceSettings serviceSettings,
+        SecretSettings secrets
+    ) {
+        super(
+            new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE),
+            new ModelSecrets(secrets)
+        );
+    }
+
+    /**
+     * Factory method to create a NvidiaChatCompletionModel with overridden model settings based on the request.
+     * If the request does not specify a model, the original model is returned.
+     *
+     * @param model the original NvidiaChatCompletionModel
+     * @param request the UnifiedCompletionRequest containing potential overrides
+     * @return a new NvidiaChatCompletionModel with overridden settings or the original model if no overrides are specified
+     */
+    public static NvidiaChatCompletionModel of(NvidiaChatCompletionModel model, UnifiedCompletionRequest request) {
+        if (request.model() == null) {
+            // If no model id is specified in the request, return the original model
+            return model;
+        }
+
+        var originalModelServiceSettings = model.getServiceSettings();
+        var overriddenServiceSettings = new NvidiaChatCompletionServiceSettings(
+            request.model(),
+            originalModelServiceSettings.uri(),
+            originalModelServiceSettings.rateLimitSettings()
+        );
+
+        return new NvidiaChatCompletionModel(
+            model.getInferenceEntityId(),
+            model.getTaskType(),
+            model.getConfigurations().getService(),
+            overriddenServiceSettings,
+            model.getSecretSettings()
+        );
+    }
+
+    @Override
+    public RateLimitSettings rateLimitSettings() {
+        return getServiceSettings().rateLimitSettings();
+    }
+
+    /**
+     * Returns the service settings specific to Nvidia chat completion.
+     *
+     * @return the NvidiaChatCompletionServiceSettings associated with this model
+     */
+    @Override
+    public NvidiaChatCompletionServiceSettings getServiceSettings() {
+        return (NvidiaChatCompletionServiceSettings) super.getServiceSettings();
+    }
+
+    /**
+     * Accepts a visitor that creates an executable action for this Nvidia chat completion model.
+     *
+     * @param creator the visitor that creates the executable action
+     * @return an ExecutableAction representing this model
+     */
+    public ExecutableAction accept(NvidiaActionVisitor creator) {
+        return creator.create(this);
+    }
+}
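The `of(...)` factory above is what lets a single chat-completion endpoint serve several NVIDIA-hosted models: a `model` field in the request body replaces the configured `model_id`, while the endpoint's URL, rate-limit settings, and API key are kept. A sketch of the contract (model resolution only, names hypothetical):

```java
// request.model() == null  -> the configured endpoint model is used unchanged;
// otherwise only model_id is swapped; url, rate limits, and secrets are preserved.
static NvidiaChatCompletionModel resolveEffectiveModel(NvidiaChatCompletionModel configured, UnifiedCompletionRequest request) {
    return NvidiaChatCompletionModel.of(configured, request);
}
```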
+ */ +public class NvidiaChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler { + + private static final String NVIDIA_ERROR = "nvidia_error"; + private static final UnifiedChatCompletionErrorParserContract NVIDIA_ERROR_PARSER = UnifiedChatCompletionErrorResponseUtils + .createErrorParserWithStringify(NVIDIA_ERROR); + + /** + * Constructor for creating a NvidiaChatCompletionResponseHandler with specified request type and response parser. + * + * @param requestType the type of request this handler will process + * @param parseFunction the function to parse the response + */ + public NvidiaChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, NVIDIA_ERROR_PARSER::parse, NVIDIA_ERROR_PARSER); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/completion/NvidiaChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/completion/NvidiaChatCompletionServiceSettings.java new file mode 100644 index 0000000000000..21142f330df36 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/completion/NvidiaChatCompletionServiceSettings.java @@ -0,0 +1,117 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.nvidia.completion; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.nvidia.NvidiaService; +import org.elasticsearch.xpack.inference.services.nvidia.NvidiaServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.net.URI; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createOptionalUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; + +/** + * Represents the settings for an Nvidia chat completion service. + * This class encapsulates the model ID, URI, and rate limit settings for the Nvidia chat completion service. + */ +public class NvidiaChatCompletionServiceSettings extends NvidiaServiceSettings { + public static final String NAME = "nvidia_chat_completion_service_settings"; + + /** + * Creates a new instance of NvidiaChatCompletionServiceSettings from a map of settings. 
+ * + * @param map the map containing the service settings + * @param context the context for parsing configuration settings + * @return a new instance of NvidiaChatCompletionServiceSettings + * @throws ValidationException if required fields are missing or invalid + */ + public static NvidiaChatCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + var model = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + var uri = extractOptionalUri(map, URL, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + NvidiaService.NAME, + context + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new NvidiaChatCompletionServiceSettings(model, uri, rateLimitSettings); + } + + /** + * Constructs a new NvidiaChatCompletionServiceSettings from a StreamInput. + * + * @param in the StreamInput to read from + * @throws IOException if an I/O error occurs during reading + */ + public NvidiaChatCompletionServiceSettings(StreamInput in) throws IOException { + super(in); + } + + /** + * Constructs a new NvidiaChatCompletionServiceSettings with the specified model ID, URI, and rate limit settings. + * + * @param modelId the ID of the model + * @param uri the URI of the service + * @param rateLimitSettings the rate limit settings for the service + */ + public NvidiaChatCompletionServiceSettings(String modelId, @Nullable URI uri, @Nullable RateLimitSettings rateLimitSettings) { + super(modelId, uri, rateLimitSettings); + } + + /** + * Constructs a new NvidiaChatCompletionServiceSettings with the specified model ID and URL. + * The rate limit settings will be set to the default value. + * + * @param modelId the ID of the model + * @param url the URL of the service + * @param rateLimitSettings the rate limit settings for the service, can be null + */ + public NvidiaChatCompletionServiceSettings(String modelId, @Nullable String url, @Nullable RateLimitSettings rateLimitSettings) { + this(modelId, createOptionalUri(url), rateLimitSettings); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + NvidiaChatCompletionServiceSettings that = (NvidiaChatCompletionServiceSettings) o; + return Objects.equals(modelId, that.modelId) + && Objects.equals(uri, that.uri) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, uri, rateLimitSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/completion/NvidiaCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/completion/NvidiaCompletionResponseHandler.java new file mode 100644 index 0000000000000..5e35bf8809650 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/completion/NvidiaCompletionResponseHandler.java @@ -0,0 +1,28 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
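`fromMap` collects every problem into a single `ValidationException` rather than failing on the first field, and falls back to the base class's 3,000-requests-per-minute default when no `rate_limit` is supplied. A hedged parsing sketch (model id and URL are illustrative):

```java
Map<String, Object> map = new HashMap<>(Map.of(
    "model_id", "meta/llama-3.1-8b-instruct",
    "url", "https://integrate.api.nvidia.com/v1/chat/completions"
));
var settings = NvidiaChatCompletionServiceSettings.fromMap(map, ConfigurationParseContext.REQUEST);
// no "rate_limit" entry in the map -> DEFAULT_RATE_LIMIT_SETTINGS applies
// (assuming value equality on RateLimitSettings, as its equals/hashCode suggest)
assert settings.rateLimitSettings().equals(new RateLimitSettings(3000));
```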
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.nvidia.completion; + +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler; + +/** + * Handles non-streaming completion responses for Nvidia inference endpoints, extending the OpenAI completion response handler. + */ +public class NvidiaCompletionResponseHandler extends OpenAiChatCompletionResponseHandler { + + /** + * Constructs a NvidiaCompletionResponseHandler with the specified request type and response parser. + * + * @param requestType The type of request being handled. + * @param parseFunction The function to parse the response. + */ + public NvidiaCompletionResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, ErrorResponse::fromResponse); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/embeddings/NvidiaEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/embeddings/NvidiaEmbeddingsModel.java new file mode 100644 index 0000000000000..029febcf94b7c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/embeddings/NvidiaEmbeddingsModel.java @@ -0,0 +1,108 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.nvidia.embeddings; + +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SecretSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.nvidia.NvidiaModel; +import org.elasticsearch.xpack.inference.services.nvidia.action.NvidiaActionVisitor; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import java.util.Map; + +/** + * Represents an Nvidia embeddings model for inference. + * This class extends the NvidiaModel and provides specific configurations and settings for embeddings tasks. + */ +public class NvidiaEmbeddingsModel extends NvidiaModel { + + /** + * Constructor for creating a NvidiaEmbeddingsModel with specified parameters. 
+ * + * @param inferenceEntityId the unique identifier for the inference entity + * @param taskType the type of task this model is designed for + * @param service the name of the inference service + * @param serviceSettings the settings for the inference service, specific to embeddings + * @param chunkingSettings the chunking settings for processing input data + * @param secrets the secret settings for the model, such as API keys or tokens + * @param context the context for parsing configuration settings + */ + public NvidiaEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettings, + ChunkingSettings chunkingSettings, + Map secrets, + ConfigurationParseContext context + ) { + this( + inferenceEntityId, + taskType, + service, + NvidiaEmbeddingsServiceSettings.fromMap(serviceSettings, context), + chunkingSettings, + DefaultSecretSettings.fromMap(secrets) + ); + } + + /** + * Constructor for creating a NvidiaEmbeddingsModel with specified parameters. + * + * @param model the base NvidiaEmbeddingsModel to copy properties from + * @param serviceSettings the settings for the inference service, specific to embeddings + */ + public NvidiaEmbeddingsModel(NvidiaEmbeddingsModel model, NvidiaEmbeddingsServiceSettings serviceSettings) { + super(model, serviceSettings); + } + + /** + * Constructor for creating a NvidiaEmbeddingsModel with specified parameters. + * + * @param inferenceEntityId the unique identifier for the inference entity + * @param taskType the type of task this model is designed for + * @param service the name of the inference service + * @param serviceSettings the settings for the inference service, specific to embeddings + * @param chunkingSettings the chunking settings for processing input data + * @param secrets the secret settings for the model, such as API keys or tokens + */ + public NvidiaEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + NvidiaEmbeddingsServiceSettings serviceSettings, + ChunkingSettings chunkingSettings, + SecretSettings secrets + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE, chunkingSettings), + new ModelSecrets(secrets) + ); + } + + @Override + public NvidiaEmbeddingsServiceSettings getServiceSettings() { + return (NvidiaEmbeddingsServiceSettings) super.getServiceSettings(); + } + + /** + * Accepts a visitor to create an executable action for this Nvidia embeddings model. + * + * @param creator the visitor that creates the executable action + * @return an ExecutableAction representing the Nvidia embeddings model + */ + public ExecutableAction accept(NvidiaActionVisitor creator) { + return creator.create(this); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/embeddings/NvidiaEmbeddingsResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/embeddings/NvidiaEmbeddingsResponseHandler.java new file mode 100644 index 0000000000000..3b7cc902c7215 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/embeddings/NvidiaEmbeddingsResponseHandler.java @@ -0,0 +1,29 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.nvidia.embeddings; + +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.services.openai.OpenAiResponseHandler; + +/** + * Handles responses for Nvidia embeddings requests, parsing the response and handling errors. + * This class extends OpenAiResponseHandler to provide specific functionality for Nvidia embeddings. + */ +public class NvidiaEmbeddingsResponseHandler extends OpenAiResponseHandler { + + /** + * Constructs a new NvidiaEmbeddingsResponseHandler with the specified request type and response parser. + * + * @param requestType the type of request this handler will process + * @param parseFunction the function to parse the response + */ + public NvidiaEmbeddingsResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, ErrorResponse::fromResponse, false); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/embeddings/NvidiaEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/embeddings/NvidiaEmbeddingsServiceSettings.java new file mode 100644 index 0000000000000..8127916e217ce --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/embeddings/NvidiaEmbeddingsServiceSettings.java @@ -0,0 +1,199 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.nvidia.embeddings; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.nvidia.NvidiaService; +import org.elasticsearch.xpack.inference.services.nvidia.NvidiaServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.net.URI; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; +import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createOptionalUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; + +/** + * Settings for the Nvidia embeddings service. + * This class encapsulates the configuration settings required to use Nvidia for generating embeddings. + */ +public class NvidiaEmbeddingsServiceSettings extends NvidiaServiceSettings { + public static final String NAME = "nvidia_embeddings_service_settings"; + + private final SimilarityMeasure similarity; + private final Integer maxInputTokens; + + /** + * Creates a new instance of NvidiaEmbeddingsServiceSettings from a map of settings. + * + * @param map the map containing the settings + * @param context the context for parsing configuration settings + * @return a new instance of NvidiaEmbeddingsServiceSettings + * @throws ValidationException if any required fields are missing or invalid + */ + public static NvidiaEmbeddingsServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + var model = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + var uri = extractOptionalUri(map, URL, validationException); + var similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); + var maxInputTokens = extractOptionalPositiveInteger( + map, + MAX_INPUT_TOKENS, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + + var rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException, NvidiaService.NAME, context); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new NvidiaEmbeddingsServiceSettings(model, uri, similarity, maxInputTokens, rateLimitSettings); + } + + /** + * Constructs a new NvidiaEmbeddingsServiceSettings from a StreamInput. 
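+ * <p>The read order must mirror {@link #writeTo(StreamOutput)}: the superclass fields are read first, followed by the optional + * similarity enum and the optional max-input-tokens integer.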
+ * + * @param in the StreamInput to read from + * @throws IOException if an I/O error occurs during reading + */ + public NvidiaEmbeddingsServiceSettings(StreamInput in) throws IOException { + super(in); + this.similarity = in.readOptionalEnum(SimilarityMeasure.class); + this.maxInputTokens = in.readOptionalVInt(); + } + + /** + * Constructs a new NvidiaEmbeddingsServiceSettings with the specified parameters. + * + * @param modelId the identifier for the model + * @param uri the URI of the Nvidia service + * @param similarity the similarity measure to use, can be null + * @param maxInputTokens the maximum number of input tokens, can be null + * @param rateLimitSettings the rate limit settings for the service, can be null + */ + public NvidiaEmbeddingsServiceSettings( + String modelId, + @Nullable URI uri, + @Nullable SimilarityMeasure similarity, + @Nullable Integer maxInputTokens, + @Nullable RateLimitSettings rateLimitSettings + ) { + super(modelId, uri, rateLimitSettings); + this.similarity = similarity; + this.maxInputTokens = maxInputTokens; + } + + /** + * Constructs a new NvidiaEmbeddingsServiceSettings with the specified parameters. + * + * @param modelId the identifier for the model + * @param url the URL of the Nvidia service + * @param similarity the similarity measure to use, can be null + * @param maxInputTokens the maximum number of input tokens, can be null + * @param rateLimitSettings the rate limit settings for the service, can be null + */ + public NvidiaEmbeddingsServiceSettings( + String modelId, + @Nullable String url, + @Nullable SimilarityMeasure similarity, + @Nullable Integer maxInputTokens, + @Nullable RateLimitSettings rateLimitSettings + ) { + this(modelId, createOptionalUri(url), similarity, maxInputTokens, rateLimitSettings); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public SimilarityMeasure similarity() { + return this.similarity; + } + + @Override + public DenseVectorFieldMapper.ElementType elementType() { + return DenseVectorFieldMapper.ElementType.FLOAT; + } + + /** + * Returns the maximum number of input tokens allowed for this service. + * + * @return the maximum input tokens, or null if not specified + */ + public Integer maxInputTokens() { + return this.maxInputTokens; + } + + /** + * Returns the rate limit settings for this service. 
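+ * <p>When the original request omitted rate limit settings, the value returned here is the service default + * (DEFAULT_RATE_LIMIT_SETTINGS, as used in fromMap), so callers do not need to handle null.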
+ * + * @return the rate limit settings, never null + */ + public RateLimitSettings rateLimitSettings() { + return this.rateLimitSettings; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion())); + out.writeOptionalVInt(maxInputTokens); + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + super.toXContentFragmentOfExposedFields(builder, params); + if (similarity != null) { + builder.field(SIMILARITY, similarity); + } + if (maxInputTokens != null) { + builder.field(MAX_INPUT_TOKENS, maxInputTokens); + } + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + NvidiaEmbeddingsServiceSettings that = (NvidiaEmbeddingsServiceSettings) o; + return Objects.equals(modelId, that.modelId) + && Objects.equals(uri, that.uri) + && Objects.equals(maxInputTokens, that.maxInputTokens) + && Objects.equals(similarity, that.similarity) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, uri, maxInputTokens, similarity, rateLimitSettings); + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/request/completion/NvidiaChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/request/completion/NvidiaChatCompletionRequest.java new file mode 100644 index 0000000000000..1daee6804e528 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/request/completion/NvidiaChatCompletionRequest.java @@ -0,0 +1,83 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.nvidia.request.completion; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.nvidia.completion.NvidiaChatCompletionModel; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; + +/** + * Nvidia Chat Completion Request + * This class is responsible for creating a request to the Nvidia chat completion model. + * It constructs an HTTP POST request with the necessary headers and body content. 
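+ * <p>A sketch of the resulting request, assuming a hypothetical endpoint path and model id: + * <pre> + * POST /v1/chat/completions HTTP/1.1 + * Authorization: Bearer &lt;api-key&gt; + * Content-Type: application/json + * + * {"model": "some-chat-model", "messages": [...]} + * </pre>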
+ */ +public class NvidiaChatCompletionRequest implements Request { + + private final NvidiaChatCompletionModel model; + private final UnifiedChatInput chatInput; + + public NvidiaChatCompletionRequest(UnifiedChatInput chatInput, NvidiaChatCompletionModel model) { + this.chatInput = Objects.requireNonNull(chatInput); + this.model = Objects.requireNonNull(model); + } + + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(model.getServiceSettings().uri()); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString(new NvidiaChatCompletionRequestEntity(chatInput, model.getServiceSettings().modelId())) + .getBytes(StandardCharsets.UTF_8) + ); + httpPost.setEntity(byteEntity); + + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters()); + httpPost.setHeader(createAuthBearerHeader(model.getSecretSettings().apiKey())); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public URI getURI() { + return model.getServiceSettings().uri(); + } + + @Override + public Request truncate() { + // No truncation for Nvidia chat completions + return this; + } + + @Override + public boolean[] getTruncationInfo() { + // No truncation for Nvidia chat completions + return null; + } + + @Override + public String getInferenceEntityId() { + return model.getInferenceEntityId(); + } + + @Override + public boolean isStreaming() { + return chatInput.stream(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/request/completion/NvidiaChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/request/completion/NvidiaChatCompletionRequestEntity.java new file mode 100644 index 0000000000000..f2cbde6fcc4a8 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/request/completion/NvidiaChatCompletionRequestEntity.java @@ -0,0 +1,40 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.nvidia.request.completion; + +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity; + +import java.io.IOException; +import java.util.Objects; + +/** + * NvidiaChatCompletionRequestEntity is responsible for creating the request entity for Nvidia chat completion. + * It implements ToXContentObject to allow serialization to XContent format. 
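+ * <p>The serialized body is the unified chat completion payload with the model id injected, for example (shape and values + * hypothetical): + * <pre> + * {"model": "some-chat-model", "messages": [{"role": "user", "content": "Hello"}], "stream": true} + * </pre>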
+ */ +public class NvidiaChatCompletionRequestEntity implements ToXContentObject { + + private final String modelId; + private final UnifiedChatCompletionRequestEntity unifiedRequestEntity; + + public NvidiaChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, String modelId) { + this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput); + this.modelId = Objects.requireNonNull(modelId); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxTokensAndSkipStreamOptionsField(modelId, params)); + builder.endObject(); + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/request/embeddings/NvidiaEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/request/embeddings/NvidiaEmbeddingsRequest.java new file mode 100644 index 0000000000000..f618cfaeaffc2 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/request/embeddings/NvidiaEmbeddingsRequest.java @@ -0,0 +1,85 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.nvidia.request.embeddings; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.nvidia.embeddings.NvidiaEmbeddingsModel; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; + +/** + * Nvidia Embeddings Request + * This class is responsible for creating a request to the Nvidia embeddings endpoint. + * It constructs an HTTP POST request with the necessary headers and body content. + */ +public class NvidiaEmbeddingsRequest implements Request { + private final NvidiaEmbeddingsModel model; + private final Truncator.TruncationResult truncationResult; + private final Truncator truncator; + + /** + * Constructs a new NvidiaEmbeddingsRequest with the specified truncator, input, and model. 
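+ * <p>Keeping a reference to the truncator lets {@link #truncate()} shorten the input and rebuild the request when the service + * rejects the original payload as too long.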
+ * + * @param truncator the truncator to handle input truncation + * @param input the input to be truncated + * @param model the Nvidia embeddings model to be used for the request + */ + public NvidiaEmbeddingsRequest(Truncator truncator, Truncator.TruncationResult input, NvidiaEmbeddingsModel model) { + this.model = Objects.requireNonNull(model); + this.truncator = Objects.requireNonNull(truncator); + this.truncationResult = Objects.requireNonNull(input); + } + + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(model.getServiceSettings().uri()); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString(new NvidiaEmbeddingsRequestEntity(truncationResult.input(), model.getServiceSettings().modelId())) + .getBytes(StandardCharsets.UTF_8) + ); + httpPost.setEntity(byteEntity); + + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters()); + httpPost.setHeader(createAuthBearerHeader(model.getSecretSettings().apiKey())); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public URI getURI() { + return model.getServiceSettings().uri(); + } + + @Override + public Request truncate() { + var truncatedInput = truncator.truncate(truncationResult.input()); + return new NvidiaEmbeddingsRequest(truncator, truncatedInput, model); + } + + @Override + public boolean[] getTruncationInfo() { + return truncationResult.truncated().clone(); + } + + @Override + public String getInferenceEntityId() { + return model.getInferenceEntityId(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/request/embeddings/NvidiaEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/request/embeddings/NvidiaEmbeddingsRequestEntity.java new file mode 100644 index 0000000000000..bc4560c903180 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/request/embeddings/NvidiaEmbeddingsRequestEntity.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.nvidia.request.embeddings; + +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +/** + * NvidiaEmbeddingsRequestEntity is responsible for creating the request entity for Nvidia embeddings. + * It implements ToXContentObject to allow serialization to XContent format. 
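+ * <p>For illustration (values hypothetical), the serialized body looks like: + * <pre> + * {"input": ["some text to embed"], "model": "some-embedding-model"} + * </pre>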
+ */ +public record NvidiaEmbeddingsRequestEntity(List input, String modelId) implements ToXContentObject { + + private static final String INPUT_FIELD = "input"; + private static final String MODEL_FIELD = "model"; + + public NvidiaEmbeddingsRequestEntity { + Objects.requireNonNull(input); + Objects.requireNonNull(modelId); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(INPUT_FIELD, input); + builder.field(MODEL_FIELD, modelId); + builder.endObject(); + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/request/rerank/NvidiaRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/request/rerank/NvidiaRerankRequest.java new file mode 100644 index 0000000000000..7127055597608 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/request/rerank/NvidiaRerankRequest.java @@ -0,0 +1,78 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.nvidia.request.rerank; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.nvidia.rerank.NvidiaRerankModel; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; + +/** + * Represents a request to the Nvidia rerank service. + * This class constructs the HTTP request with the necessary headers and body content. 
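+ * <p>Unlike embeddings requests, rerank requests are never truncated: {@link #truncate()} returns the request unchanged and + * {@link #getTruncationInfo()} returns null.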
+ * @param query the query string to rerank against + * @param input the list of input documents to be reranked + * @param model the Nvidia rerank model configuration + */ +public record NvidiaRerankRequest(String query, List input, NvidiaRerankModel model) implements Request { + + public NvidiaRerankRequest { + Objects.requireNonNull(input); + Objects.requireNonNull(query); + Objects.requireNonNull(model); + } + + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(getURI()); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString(new NvidiaRerankRequestEntity(model.getServiceSettings().modelId(), query, input)) + .getBytes(StandardCharsets.UTF_8) + ); + httpPost.setEntity(byteEntity); + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters()); + + httpPost.setHeader(createAuthBearerHeader(model.getSecretSettings().apiKey())); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public String getInferenceEntityId() { + return model.getInferenceEntityId(); + } + + @Override + public URI getURI() { + return model.getServiceSettings().uri(); + } + + @Override + public Request truncate() { + // Not applicable for rerank, only used in text embedding requests + return this; + } + + @Override + public boolean[] getTruncationInfo() { + // Not applicable for rerank, only used in text embedding requests + return null; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/request/rerank/NvidiaRerankRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/request/rerank/NvidiaRerankRequestEntity.java new file mode 100644 index 0000000000000..9bf4371e461f5 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/request/rerank/NvidiaRerankRequestEntity.java @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.nvidia.request.rerank; + +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +/** + * Entity representing the request body for Nvidia rerank requests. 
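+ * <p>For illustration (values hypothetical), the serialized body looks like: + * <pre> + * {"model": "some-rerank-model", "query": {"text": "some query"}, "passages": [{"text": "passage one"}, {"text": "passage two"}]} + * </pre>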
+ * @param modelId the model identifier + * @param query the query string + * @param passages the list of passages to be reranked + */ +public record NvidiaRerankRequestEntity(String modelId, String query, List passages) implements ToXContentObject { + + private static final String MODEL_FIELD = "model"; + private static final String QUERY_FIELD = "query"; + private static final String PASSAGES_FIELD = "passages"; + private static final String TEXT_FIELD = "text"; + + public NvidiaRerankRequestEntity { + Objects.requireNonNull(modelId); + Objects.requireNonNull(query); + Objects.requireNonNull(passages); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + builder.field(MODEL_FIELD, modelId); + + builder.startObject(QUERY_FIELD); + builder.field(TEXT_FIELD, query); + builder.endObject(); + + builder.startArray(PASSAGES_FIELD); + for (String passage : passages) { + builder.startObject(); + builder.field(TEXT_FIELD, passage); + builder.endObject(); + } + builder.endArray(); + + builder.endObject(); + return builder; + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/rerank/NvidiaRerankModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/rerank/NvidiaRerankModel.java new file mode 100644 index 0000000000000..9aa4be59927e3 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/rerank/NvidiaRerankModel.java @@ -0,0 +1,102 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.nvidia.rerank; + +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SecretSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.nvidia.NvidiaModel; +import org.elasticsearch.xpack.inference.services.nvidia.action.NvidiaActionVisitor; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import java.util.Map; + +/** + * Represents an Nvidia rerank model for inference. + * This class extends the NvidiaModel and provides specific configurations and settings for rerank tasks. + */ +public class NvidiaRerankModel extends NvidiaModel { + + /** + * Constructor for creating a NvidiaRerankModel with specified parameters.
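+ * <p>The raw maps are parsed with NvidiaRerankServiceSettings.fromMap and DefaultSecretSettings.fromMap, so invalid or missing + * fields surface as a ValidationException before the model is constructed.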
+ * + * @param inferenceEntityId the unique identifier for the inference entity + * @param taskType the type of task this model is designed for + * @param service the name of the inference service + * @param serviceSettings the settings for the inference service, specific to rerank + * @param secrets the secret settings for the model, such as API keys or tokens + * @param context the context for parsing configuration settings + */ + public NvidiaRerankModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettings, + Map secrets, + ConfigurationParseContext context + ) { + this( + inferenceEntityId, + taskType, + service, + NvidiaRerankServiceSettings.fromMap(serviceSettings, context), + DefaultSecretSettings.fromMap(secrets) + ); + } + + /** + * Copy constructor that creates a new NvidiaRerankModel from an existing model, replacing its service settings. + * + * @param model the base NvidiaRerankModel to copy properties from + * @param serviceSettings the settings for the inference service, specific to rerank + */ + public NvidiaRerankModel(NvidiaRerankModel model, NvidiaRerankServiceSettings serviceSettings) { + super(model, serviceSettings); + } + + /** + * Constructor for creating a NvidiaRerankModel from already-parsed service and secret settings. + * + * @param inferenceEntityId the unique identifier for the inference entity + * @param taskType the type of task this model is designed for + * @param service the name of the inference service + * @param serviceSettings the settings for the inference service, specific to rerank + * @param secrets the secret settings for the model, such as API keys or tokens + */ + public NvidiaRerankModel( + String inferenceEntityId, + TaskType taskType, + String service, + NvidiaRerankServiceSettings serviceSettings, + SecretSettings secrets + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE), + new ModelSecrets(secrets) + ); + } + + @Override + public NvidiaRerankServiceSettings getServiceSettings() { + return (NvidiaRerankServiceSettings) super.getServiceSettings(); + } + + /** + * Accepts a visitor to create an executable action for this Nvidia rerank model. + * + * @param creator the visitor that creates the executable action + * @return an ExecutableAction representing the Nvidia rerank model + */ + public ExecutableAction accept(NvidiaActionVisitor creator) { + return creator.create(this); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/rerank/NvidiaRerankServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/rerank/NvidiaRerankServiceSettings.java new file mode 100644 index 0000000000000..87eb0b43a0bb3 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/rerank/NvidiaRerankServiceSettings.java @@ -0,0 +1,117 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0.
 */ + +package org.elasticsearch.xpack.inference.services.nvidia.rerank; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.nvidia.NvidiaService; +import org.elasticsearch.xpack.inference.services.nvidia.NvidiaServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.net.URI; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createOptionalUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; + +/** + * Represents the settings for an Nvidia rerank service. + * This class encapsulates the model ID, URI, and rate limit settings for the Nvidia rerank service. + */ +public class NvidiaRerankServiceSettings extends NvidiaServiceSettings { + public static final String NAME = "nvidia_rerank_service_settings"; + + /** + * Creates a new instance of NvidiaRerankServiceSettings from a map of settings. + * + * @param map the map containing the service settings + * @param context the context for parsing configuration settings + * @return a new instance of NvidiaRerankServiceSettings + * @throws ValidationException if required fields are missing or invalid + */ + public static NvidiaRerankServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + var model = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + var uri = extractOptionalUri(map, URL, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + NvidiaService.NAME, + context + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new NvidiaRerankServiceSettings(model, uri, rateLimitSettings); + } + + /** + * Constructs a new NvidiaRerankServiceSettings from a StreamInput. + * + * @param in the StreamInput to read from + * @throws IOException if an I/O error occurs during reading + */ + public NvidiaRerankServiceSettings(StreamInput in) throws IOException { + super(in); + } + + /** + * Constructs a new NvidiaRerankServiceSettings with the specified model ID, URI, and rate limit settings. + * + * @param modelId the ID of the model + * @param uri the URI of the service + * @param rateLimitSettings the rate limit settings for the service + */ + public NvidiaRerankServiceSettings(String modelId, @Nullable URI uri, @Nullable RateLimitSettings rateLimitSettings) { + super(modelId, uri, rateLimitSettings); + } + + /** + * Constructs a new NvidiaRerankServiceSettings with the specified model ID and URL. + * If no rate limit settings are provided, the default value is applied.
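+ * <p>The url string, when non-null, is converted to a URI with createOptionalUri; a null url leaves the service URI unset.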
+ * + * @param modelId the ID of the model + * @param url the URL of the service + * @param rateLimitSettings the rate limit settings for the service, can be null + */ + public NvidiaRerankServiceSettings(String modelId, @Nullable String url, @Nullable RateLimitSettings rateLimitSettings) { + this(modelId, createOptionalUri(url), rateLimitSettings); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + NvidiaRerankServiceSettings that = (NvidiaRerankServiceSettings) o; + return Objects.equals(modelId, that.modelId) + && Objects.equals(uri, that.uri) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, uri, rateLimitSettings); + } +}