Skip to content

Commit 159e061

Browse files
authored
Fix TRTLLM API (#301)
1 parent 2ac304c commit 159e061

File tree

3 files changed

+7
-21
lines changed

3 files changed

+7
-21
lines changed

nemo_deploy/nlp/trtllm_api_deployable.py

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,6 @@
3838

3939
try:
4040
from tensorrt_llm import SamplingParams
41-
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
4241
from tensorrt_llm.llmapi.llm import LLM, TokenizerBase
4342

4443
HAVE_TENSORRT_LLM = True
@@ -90,9 +89,6 @@ def __init__(
9089
if not HAVE_TRITON:
9190
raise ImportError(MISSING_TRITON_MSG)
9291

93-
config_args = {k: kwargs.pop(k) for k in PyTorchConfig.__annotations__.keys() & kwargs.keys()}
94-
pytorch_config = PyTorchConfig(**config_args)
95-
9692
self.model = LLM(
9793
model=hf_model_id_path,
9894
tokenizer=hf_model_id_path if tokenizer is None else tokenizer,
@@ -104,7 +100,6 @@ def __init__(
104100
max_num_tokens=max_num_tokens,
105101
backend=backend,
106102
dtype=dtype,
107-
pytorch_backend_config=pytorch_config,
108103
**kwargs,
109104
)
110105

scripts/deploy/nlp/deploy_trtllm_api_triton.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -48,9 +48,8 @@ def get_args():
4848
)
4949
parser.add_argument("-dt", "--dtype", default="auto", type=str, help="Model data type")
5050
parser.add_argument("-ab", "--attn_backend", default="TRTLLM", type=str, help="Attention kernel backend")
51-
parser.add_argument("-eos", "--enable_overlap_scheduler", action="store_true", help="Enable overlap scheduler")
51+
parser.add_argument("-dos", "--disable_overlap_scheduler", action="store_true", help="Disable overlap scheduler")
5252
parser.add_argument("-ecp", "--enable_chunked_prefill", action="store_true", help="Enable chunked prefill")
53-
parser.add_argument("-ucg", "--use_cuda_graph", action="store_true", help="Use CUDA graph")
5453
parser.add_argument("-dm", "--debug_mode", action="store_true", help="Enable debug mode")
5554
args = parser.parse_args()
5655
return args
@@ -79,9 +78,8 @@ def trtllm_deploy():
7978
max_num_tokens=args.max_num_tokens,
8079
dtype=args.dtype,
8180
attn_backend=args.attn_backend,
82-
enable_overlap_scheduler=args.enable_overlap_scheduler,
81+
disable_overlap_scheduler=args.disable_overlap_scheduler,
8382
enable_chunked_prefill=args.enable_chunked_prefill,
84-
use_cuda_graph=args.use_cuda_graph,
8583
)
8684

8785
try:

tests/unit_tests/deploy/test_trtllm_api_deployable.py

Lines changed: 5 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -37,13 +37,6 @@ def mock_sampling_params():
3737
yield mock
3838

3939

40-
@pytest.fixture
41-
def mock_pytorch_config():
42-
with patch("nemo_deploy.nlp.trtllm_api_deployable.PyTorchConfig") as mock:
43-
mock.__annotations__ = {}
44-
yield mock
45-
46-
4740
try:
4841
import tensorrt_llm # noqa: F401
4942

@@ -55,7 +48,7 @@ def mock_pytorch_config():
5548
@pytest.mark.skipif(not HAVE_TENSORRT_LLM, reason="TensorRT-LLM is not installed")
5649
@pytest.mark.run_only_on("GPU")
5750
class TestTensorRTLLMAPIDeployable:
58-
def test_initialization_with_defaults(self, mock_pytorch_config):
51+
def test_initialization_with_defaults(self):
5952
from nemo_deploy.nlp.trtllm_api_deployable import TensorRTLLMAPIDeployable
6053

6154
with patch("nemo_deploy.nlp.trtllm_api_deployable.LLM") as mock_llm_class:
@@ -67,7 +60,7 @@ def test_initialization_with_defaults(self, mock_pytorch_config):
6760
assert deployer.model == mock_llm_instance
6861
mock_llm_class.assert_called_once()
6962

70-
def test_initialization_with_custom_params(self, mock_pytorch_config):
63+
def test_initialization_with_custom_params(self):
7164
from nemo_deploy.nlp.trtllm_api_deployable import TensorRTLLMAPIDeployable
7265

7366
with patch("nemo_deploy.nlp.trtllm_api_deployable.LLM") as mock_llm_class:
@@ -109,7 +102,7 @@ def test_generate_without_model(self):
109102
with pytest.raises(RuntimeError, match="Model is not initialized"):
110103
deployer.generate(prompts=["test prompt"])
111104

112-
def test_generate_with_model(self, mock_llm, mock_sampling_params, mock_pytorch_config):
105+
def test_generate_with_model(self, mock_llm, mock_sampling_params):
113106
from nemo_deploy.nlp.trtllm_api_deployable import TensorRTLLMAPIDeployable
114107

115108
with patch("nemo_deploy.nlp.trtllm_api_deployable.LLM") as mock_llm_class:
@@ -122,7 +115,7 @@ def test_generate_with_model(self, mock_llm, mock_sampling_params, mock_pytorch_
122115
mock_llm.generate.assert_called_once()
123116
mock_sampling_params.assert_called_once()
124117

125-
def test_generate_with_parameters(self, mock_llm, mock_sampling_params, mock_pytorch_config):
118+
def test_generate_with_parameters(self, mock_llm, mock_sampling_params):
126119
from nemo_deploy.nlp.trtllm_api_deployable import TensorRTLLMAPIDeployable
127120

128121
with patch("nemo_deploy.nlp.trtllm_api_deployable.LLM") as mock_llm_class:
@@ -135,7 +128,7 @@ def test_generate_with_parameters(self, mock_llm, mock_sampling_params, mock_pyt
135128
mock_llm.generate.assert_called_once()
136129
mock_sampling_params.assert_called_once_with(max_tokens=100, temperature=0.8, top_k=50, top_p=0.95)
137130

138-
def test_triton_input_output_config(self, mock_pytorch_config):
131+
def test_triton_input_output_config(self):
139132
from nemo_deploy.nlp.trtllm_api_deployable import TensorRTLLMAPIDeployable
140133

141134
with patch("nemo_deploy.nlp.trtllm_api_deployable.LLM"):

0 commit comments

Comments
 (0)