
Commit 0becc99

Support TRTLLM LLM-API Deployment (#55)
Signed-off-by: Onur Yilmaz <oyilmaz@nvidia.com>
Co-authored-by: Onur Yilmaz <35306097+oyilmaz-nvidia@users.noreply.github.com>
Co-authored-by: Onur Yilmaz <oyilmaz@nvidia.com>
1 parent 6eb3ab6 commit 0becc99

File tree: 10 files changed, +777 −19 lines changed


.github/workflows/cicd-main.yml

Lines changed: 2 additions & 0 deletions
@@ -197,6 +197,8 @@ jobs:
             runner: linux-amd64-gpu-rtxa6000-latest-2-nemo
           - script: L2_NeMo_2_Export_Qnemo_TRT_LLM
             runner: linux-amd64-gpu-rtxa6000-latest-2-nemo
+          - script: L2_TRTLLM_API_Deploy_Query
+            runner: linux-amd64-gpu-rtxa6000-latest-2-nemo
       needs: [cicd-unit-tests]
       runs-on: ${{ matrix.runner }}
       name: ${{ matrix.is_optional && 'PLEASEFIXME_' || '' }}${{ matrix.script }}

nemo_deploy/nlp/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -13,10 +13,11 @@
 # limitations under the License.


-from nemo_deploy.nlp.query_llm import NemoQueryLLM, NemoQueryLLMHF, NemoQueryLLMPyTorch
+from nemo_deploy.nlp.query_llm import NemoQueryLLM, NemoQueryLLMHF, NemoQueryLLMPyTorch, NemoQueryTRTLLMAPI

 __all__ = [
     "NemoQueryLLM",
     "NemoQueryLLMHF",
     "NemoQueryLLMPyTorch",
+    "NemoQueryTRTLLMAPI",
 ]
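
A quick sanity check that the re-export works as intended after this change (a minimal sketch; it only exercises the import surface, not a running server):

import nemo_deploy.nlp as nlp

# After this change, the new query client is part of the package's public API
# alongside the existing NemoQueryLLM classes.
assert "NemoQueryTRTLLMAPI" in nlp.__all__
print(nlp.NemoQueryTRTLLMAPI)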

nemo_deploy/nlp/query_llm.py

Lines changed: 75 additions & 18 deletions
@@ -62,12 +62,6 @@ class NemoQueryLLMPyTorch(NemoQueryLLMBase):
         print("prompts: ", prompts)
     """

-    def __init__(self, url, model_name):
-        super().__init__(
-            url=url,
-            model_name=model_name,
-        )
-
     # these arguments are explicitly defined in order to make it clear to user what they can pass
     # names and optionality should exactly match the get_triton_input() results for MegatronGPTDeployable
     def query_llm(
@@ -204,12 +198,6 @@ class NemoQueryLLMHF(NemoQueryLLMBase):
         print("prompts: ", prompts)
     """

-    def __init__(self, url, model_name):
-        super().__init__(
-            url=url,
-            model_name=model_name,
-        )
-
     # these arguments are explicitly defined in order to make it clear to user what they can pass
     # names and optionality should exactly match the get_triton_input() results for HuggingFaceLLMDeploy
     def query_llm(
@@ -322,12 +310,6 @@ class NemoQueryLLM(NemoQueryLLMBase):
         print("prompts: ", prompts)
     """

-    def __init__(self, url, model_name):
-        super().__init__(
-            url=url,
-            model_name=model_name,
-        )
-
     def query_llm(
         self,
         prompts,
@@ -459,3 +441,78 @@ def query_llm(
             return sentences
         else:
             return result_dict["outputs"]
+
+
+class NemoQueryTRTLLMAPI(NemoQueryLLMBase):
+    """Sends a query to Triton for TensorRT-LLM API deployment inference.
+
+    Example:
+        from nemo_deploy import NemoQueryTRTLLMAPI
+
+        nq = NemoQueryTRTLLMAPI(url="localhost", model_name="GPT-2B")
+
+        prompts = ["hello, testing GPT inference", "another GPT inference test?"]
+        output = nq.query_llm(
+            prompts=prompts,
+            max_length=100,
+            top_k=1,
+            top_p=None,
+            temperature=None,
+        )
+        print("prompts: ", prompts)
+    """
+
+    def query_llm(
+        self,
+        prompts: List[str],
+        max_length: int = 256,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        temperature: Optional[float] = None,
+        init_timeout: float = 60.0,
+    ):
+        """
+        Query the Triton server synchronously and return a list of responses.
+
+        Args:
+            prompts (List[str]): list of sentences.
+            max_length (int): maximum number of tokens to generate.
+            top_k (int): limits sampling to the K most likely tokens.
+            top_p (float): limits sampling to the smallest set of tokens whose cumulative probability exceeds p.
+            temperature (float): temperature applied to the softmax over the output logits.
+            init_timeout (float): connection timeout in seconds.
+
+        Returns:
+            List[str]: A list of generated texts, one for each input prompt.
+        """
+        prompts = str_list2numpy(prompts)
+        inputs = {
+            "prompts": prompts,
+        }
+
+        if max_length is not None:
+            inputs["max_length"] = np.full(prompts.shape, max_length, dtype=np.int_)
+
+        if temperature is not None:
+            inputs["temperature"] = np.full(prompts.shape, temperature, dtype=np.single)
+
+        if top_k is not None:
+            inputs["top_k"] = np.full(prompts.shape, top_k, dtype=np.int_)
+
+        if top_p is not None:
+            inputs["top_p"] = np.full(prompts.shape, top_p, dtype=np.single)
+
+        with ModelClient(self.url, self.model_name, init_timeout_s=init_timeout) as client:
+            result_dict = client.infer_batch(**inputs)
+            output_type = client.model_config.outputs[0].dtype
+
+        if output_type == np.bytes_:
+            if "sentences" in result_dict.keys():
+                output = result_dict["sentences"]
+            else:
+                return "Unknown output keyword."
+
+            sentences = np.char.decode(output.astype("bytes"), "utf-8")
+            return sentences
+        else:
+            return result_dict["sentences"]
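
The docstring example above shows the intended call pattern; below is a slightly fuller usage sketch of the new client. It assumes a Triton server is already serving a TensorRT-LLM LLM-API model at "localhost" under the name "GPT-2B" (both placeholders taken from the docstring).

from nemo_deploy.nlp import NemoQueryTRTLLMAPI

# "localhost" and "GPT-2B" are placeholders; point these at your own Triton deployment.
nq = NemoQueryTRTLLMAPI(url="localhost", model_name="GPT-2B")

prompts = ["hello, testing GPT inference", "another GPT inference test?"]
sentences = nq.query_llm(
    prompts=prompts,
    max_length=100,      # sent as the "max_length" input tensor
    top_k=1,             # optional sampling controls; None means the input is not sent
    top_p=None,
    temperature=None,
    init_timeout=60.0,   # seconds to wait for the Triton connection
)

for prompt, sentence in zip(prompts, sentences):
    print(prompt, "->", sentence)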
Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from pathlib import Path
from typing import List, Optional, Union

import numpy as np
from pytriton.decorators import batch, first_value
from pytriton.model_config import Tensor
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm.llmapi.llm import LLM, TokenizerBase
from transformers import PreTrainedTokenizerBase

from nemo_deploy import ITritonDeployable
from nemo_deploy.utils import cast_output, str_ndarray2list

LOGGER = logging.getLogger("NeMo")


class TensorRTLLMAPIDeployable(ITritonDeployable):
    """A Triton inference server compatible wrapper for the TensorRT-LLM LLM-API.

    This class provides a standardized interface for deploying TensorRT-LLM LLM-API models
    in Triton inference server. It handles model loading, inference, and deployment configurations.

    Args:
        hf_model_id_path (str): Path to the HuggingFace model or model identifier.
            Can be a local path or a model ID from HuggingFace Hub.
        tokenizer (Optional[Union[str, Path, TokenizerBase, PreTrainedTokenizerBase]]):
            Path to the tokenizer or tokenizer instance.
        tensor_parallel_size (int): Tensor parallelism size. Defaults to 1.
        pipeline_parallel_size (int): Pipeline parallelism size. Defaults to 1.
        moe_expert_parallel_size (int): MOE expert parallelism size. Defaults to -1.
        moe_tensor_parallel_size (int): MOE tensor parallelism size. Defaults to -1.
        max_batch_size (int): Maximum batch size. Defaults to 8.
        max_num_tokens (int): Maximum total tokens across all sequences in a batch. Defaults to 8192.
        backend (str): Backend to use for TRTLLM. Defaults to "pytorch".
        dtype (str): Model data type. Defaults to "auto".
        **kwargs: Additional keyword arguments to pass to model loading.
    """

    def __init__(
        self,
        hf_model_id_path: str,
        tokenizer: Optional[Union[str, Path, TokenizerBase, PreTrainedTokenizerBase]] = None,
        tensor_parallel_size: int = 1,
        pipeline_parallel_size: int = 1,
        moe_expert_parallel_size: int = -1,
        moe_tensor_parallel_size: int = -1,
        max_batch_size: int = 8,
        max_num_tokens: int = 8192,
        backend: str = "pytorch",
        dtype: str = "auto",
        **kwargs,
    ):
        config_args = {k: kwargs.pop(k) for k in PyTorchConfig.__annotations__.keys() & kwargs.keys()}
        pytorch_config = PyTorchConfig(**config_args)

        self.model = LLM(
            model=hf_model_id_path,
            tokenizer=hf_model_id_path if tokenizer is None else tokenizer,
            tensor_parallel_size=tensor_parallel_size,
            pipeline_parallel_size=pipeline_parallel_size,
            moe_expert_parallel_size=moe_expert_parallel_size,
            moe_tensor_parallel_size=moe_tensor_parallel_size,
            max_batch_size=max_batch_size,
            max_num_tokens=max_num_tokens,
            backend=backend,
            dtype=dtype,
            pytorch_backend_config=pytorch_config,
            **kwargs,
        )

    def generate(
        self,
        prompts: List[str],
        max_length: int = 256,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        **kwargs,
    ) -> List[str]:
        """Generate text based on the provided input prompts.

        This method processes input prompts through the loaded model and
        generates text according to the specified parameters.

        Args:
            prompts: List of input prompts.
            max_length: Maximum number of tokens to generate. Defaults to 256.
            temperature: Sampling temperature. Defaults to None.
            top_k: Number of highest probability tokens to consider. Defaults to None.
            top_p: Cumulative probability threshold for token sampling. Defaults to None.
            **kwargs: Additional keyword arguments passed to the sampling params.

        Returns:
            List[str]: A list of generated texts, one for each input prompt.

        Raises:
            RuntimeError: If the model is not initialized.
        """
        if not self.model:
            raise RuntimeError("Model is not initialized")

        sampling_params = SamplingParams(
            max_tokens=max_length,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            **kwargs,
        )

        outputs = self.model.generate(
            inputs=prompts,
            sampling_params=sampling_params,
        )

        return [output.outputs[0].text for output in outputs]

    @property
    def get_triton_input(self):
        inputs = (
            Tensor(name="prompts", shape=(-1,), dtype=bytes),
            Tensor(name="max_length", shape=(-1,), dtype=np.int_, optional=True),
            Tensor(name="max_batch_size", shape=(-1,), dtype=np.int_, optional=True),
            Tensor(name="top_k", shape=(-1,), dtype=np.int_, optional=True),
            Tensor(name="top_p", shape=(-1,), dtype=np.single, optional=True),
            Tensor(name="temperature", shape=(-1,), dtype=np.single, optional=True),
        )
        return inputs

    @property
    def get_triton_output(self):
        return (Tensor(name="sentences", shape=(-1,), dtype=bytes),)

    @batch
    @first_value("temperature", "top_k", "top_p", "max_length")
    def triton_infer_fn(self, **inputs: np.ndarray):
        output_infer = {}

        prompts = str_ndarray2list(inputs.pop("prompts"))
        temperature = inputs.pop("temperature", None)
        top_k = inputs.pop("top_k", None)
        top_p = inputs.pop("top_p", None)
        max_length = inputs.pop("max_length", 256)

        output = self.generate(
            prompts=prompts,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            max_length=max_length,
        )

        output_infer = {"sentences": cast_output(output, np.bytes_)}

        return output_infer
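
For context, a minimal sketch of how this deployable could be bound to a PyTriton server so that NemoQueryTRTLLMAPI can reach it. The module import path and the HuggingFace model id are assumptions (the new file's path is not shown in this extract), and the repository's own deployment scripts may wire this up differently.

from pytriton.model_config import ModelConfig
from pytriton.triton import Triton

# Hypothetical module path: the diff does not show where this file lives in the package.
from nemo_deploy.nlp.trtllm_api_deployable import TensorRTLLMAPIDeployable

# Placeholder HuggingFace model id or local checkpoint path.
model = TensorRTLLMAPIDeployable(hf_model_id_path="meta-llama/Llama-3.2-1B", max_batch_size=8)

with Triton() as triton:
    triton.bind(
        model_name="GPT-2B",                 # must match model_name on the query client side
        infer_func=model.triton_infer_fn,    # @batch-decorated inference entry point
        inputs=model.get_triton_input,       # properties returning tuples of pytriton Tensor specs
        outputs=model.get_triton_output,
        config=ModelConfig(max_batch_size=8),
    )
    triton.serve()                           # blocks and serves HTTP/gRPC endpoints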
