|
| 1 | +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import logging |
| 16 | +from pathlib import Path |
| 17 | +from typing import List, Optional, Union |
| 18 | + |
| 19 | +import numpy as np |
| 20 | +from pytriton.decorators import batch, first_value |
| 21 | +from pytriton.model_config import Tensor |
| 22 | +from tensorrt_llm import SamplingParams |
| 23 | +from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig |
| 24 | +from tensorrt_llm.llmapi.llm import LLM, TokenizerBase |
| 25 | +from transformers import PreTrainedTokenizerBase |
| 26 | + |
| 27 | +from nemo_deploy import ITritonDeployable |
| 28 | +from nemo_deploy.utils import cast_output, str_ndarray2list |
| 29 | + |
| 30 | +LOGGER = logging.getLogger("NeMo") |
| 31 | + |
| 32 | + |
| 33 | +class TensorRTLLMAPIDeployable(ITritonDeployable): |
| 34 | + """A Triton inference server compatible wrapper for TensorRT-LLM LLM-API. |
| 35 | +
|
| 36 | + This class provides a standardized interface for deploying TensorRT-LLM LLM-API |
| 37 | + in Triton inference server. It handles model loading, inference, and deployment configurations. |
| 38 | +
|
| 39 | + Args: |
| 40 | + hf_model_id_path (str): Path to the HuggingFace model or model identifier. |
| 41 | + Can be a local path or a model ID from HuggingFace Hub. |
| 42 | + tokenizer (Optional[Union[str, Path, TokenizerBase, PreTrainedTokenizerBase]]): |
| 43 | + Path to the tokenizer or tokenizer instance. |
| 44 | + tensor_parallel_size (int): Tensor parallelism size. Defaults to 1. |
| 45 | + pipeline_parallel_size (int): Pipeline parallelism size. Defaults to 1. |
| 46 | + moe_expert_parallel_size (int): MOE expert parallelism size. Defaults to -1. |
| 47 | + moe_tensor_parallel_size (int): MOE tensor parallelism size. Defaults to -1. |
| 48 | + max_batch_size (int): Maximum batch size. Defaults to 8. |
| 49 | + max_num_tokens (int): Maximum total tokens across all sequences in a batch. Defaults to 8192. |
| 50 | + backend (str): Backend to use for TRTLLM. Defaults to "pytorch". |
| 51 | + dtype (str): Model data type. Defaults to "auto". |
| 52 | + **kwargs: Additional keyword arguments to pass to model loading. |
| 53 | + """ |
| 54 | + |
| 55 | + def __init__( |
| 56 | + self, |
| 57 | + hf_model_id_path: str, |
| 58 | + tokenizer: Optional[Union[str, Path, TokenizerBase, PreTrainedTokenizerBase]] = None, |
| 59 | + tensor_parallel_size: int = 1, |
| 60 | + pipeline_parallel_size: int = 1, |
| 61 | + moe_expert_parallel_size: int = -1, |
| 62 | + moe_tensor_parallel_size: int = -1, |
| 63 | + max_batch_size: int = 8, |
| 64 | + max_num_tokens: int = 8192, |
| 65 | + backend: str = "pytorch", |
| 66 | + dtype: str = "auto", |
| 67 | + **kwargs, |
| 68 | + ): |
| 69 | + config_args = {k: kwargs.pop(k) for k in PyTorchConfig.__annotations__.keys() & kwargs.keys()} |
| 70 | + pytorch_config = PyTorchConfig(**config_args) |
| 71 | + |
| 72 | + self.model = LLM( |
| 73 | + model=hf_model_id_path, |
| 74 | + tokenizer=hf_model_id_path if tokenizer is None else tokenizer, |
| 75 | + tensor_parallel_size=tensor_parallel_size, |
| 76 | + pipeline_parallel_size=pipeline_parallel_size, |
| 77 | + moe_expert_parallel_size=moe_expert_parallel_size, |
| 78 | + moe_tensor_parallel_size=moe_tensor_parallel_size, |
| 79 | + max_batch_size=max_batch_size, |
| 80 | + max_num_tokens=max_num_tokens, |
| 81 | + backend=backend, |
| 82 | + dtype=dtype, |
| 83 | + pytorch_backend_config=pytorch_config, |
| 84 | + **kwargs, |
| 85 | + ) |
| 86 | + |
| 87 | + def generate( |
| 88 | + self, |
| 89 | + prompts: List[str], |
| 90 | + max_length: int = 256, |
| 91 | + temperature: Optional[float] = None, |
| 92 | + top_k: Optional[int] = None, |
| 93 | + top_p: Optional[float] = None, |
| 94 | + **kwargs, |
| 95 | + ) -> List[str]: |
| 96 | + """Generate text based on the provided input prompts. |
| 97 | +
|
| 98 | + This method processes input prompts through the loaded model and |
| 99 | + generates text according to the specified parameters. |
| 100 | +
|
| 101 | + Args: |
| 102 | + prompts: List of input prompts |
| 103 | + max_length: Maximum number of tokens to generate. Defaults to 256. |
| 104 | + temperature: Sampling temperature. Defaults to None. |
| 105 | + top_k: Number of highest probability tokens to consider. Defaults to None. |
| 106 | + top_p: Cumulative probability threshold for token sampling. Defaults to None. |
| 107 | + **kwargs: Additional keyword arguments to sampling params. |
| 108 | +
|
| 109 | + Returns: |
| 110 | + List[str]: A list of generated texts, one for each input prompt. |
| 111 | +
|
| 112 | + Raises: |
| 113 | + RuntimeError: If the model is not initialized. |
| 114 | + """ |
| 115 | + if not self.model: |
| 116 | + raise RuntimeError("Model is not initialized") |
| 117 | + |
| 118 | + sampling_params = SamplingParams( |
| 119 | + max_tokens=max_length, |
| 120 | + temperature=temperature, |
| 121 | + top_k=top_k, |
| 122 | + top_p=top_p, |
| 123 | + **kwargs, |
| 124 | + ) |
| 125 | + |
| 126 | + outputs = self.model.generate( |
| 127 | + inputs=prompts, |
| 128 | + sampling_params=sampling_params, |
| 129 | + ) |
| 130 | + |
| 131 | + return [output.outputs[0].text for output in outputs] |
| 132 | + |
| 133 | + @property |
| 134 | + def get_triton_input(self): |
| 135 | + inputs = ( |
| 136 | + Tensor(name="prompts", shape=(-1,), dtype=bytes), |
| 137 | + Tensor(name="max_length", shape=(-1,), dtype=np.int_, optional=True), |
| 138 | + Tensor(name="max_batch_size", shape=(-1,), dtype=np.int_, optional=True), |
| 139 | + Tensor(name="top_k", shape=(-1,), dtype=np.int_, optional=True), |
| 140 | + Tensor(name="top_p", shape=(-1,), dtype=np.single, optional=True), |
| 141 | + Tensor(name="temperature", shape=(-1,), dtype=np.single, optional=True), |
| 142 | + ) |
| 143 | + return inputs |
| 144 | + |
| 145 | + @property |
| 146 | + def get_triton_output(self): |
| 147 | + return (Tensor(name="sentences", shape=(-1,), dtype=bytes),) |
| 148 | + |
| 149 | + @batch |
| 150 | + @first_value("temperature", "top_k", "top_p", "max_length") |
| 151 | + def triton_infer_fn(self, **inputs: np.ndarray): |
| 152 | + output_infer = {} |
| 153 | + |
| 154 | + prompts = str_ndarray2list(inputs.pop("prompts")) |
| 155 | + temperature = inputs.pop("temperature", None) |
| 156 | + top_k = inputs.pop("top_k", None) |
| 157 | + top_p = inputs.pop("top_p", None) |
| 158 | + max_length = inputs.pop("max_length", 256) |
| 159 | + |
| 160 | + output = self.generate( |
| 161 | + prompts=prompts, |
| 162 | + temperature=temperature, |
| 163 | + top_k=top_k, |
| 164 | + top_p=top_p, |
| 165 | + max_length=max_length, |
| 166 | + ) |
| 167 | + |
| 168 | + output_infer = {"sentences": cast_output(output, np.bytes_)} |
| 169 | + |
| 170 | + return output_infer |
0 commit comments