
Commit 74df9b1

[#9602][feat] AutoDeploy: Support TRTLLM Sampler (#9641)
Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com>
1 parent cb87c44 commit 74df9b1

3 files changed: +127 / -12 lines

tensorrt_llm/_torch/auto_deploy/llm_args.py

Lines changed: 6 additions & 1 deletion
@@ -8,7 +8,7 @@

 from tensorrt_llm.models.modeling_utils import QuantConfig

-from ...llmapi.llm_args import BaseLlmArgs, BuildConfig, KvCacheConfig, _ParallelConfig
+from ...llmapi.llm_args import BaseLlmArgs, BuildConfig, KvCacheConfig, SamplerType, _ParallelConfig
 from .models import ModelFactory, ModelFactoryRegistry
 from .utils._config import DynamicYamlMixInForSettings
 from .utils.logger import ad_logger
@@ -130,6 +130,11 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
        "supported in AutoDeploy.",
    )

+    sampler_type: Union[str, SamplerType] = Field(
+        default=SamplerType.TorchSampler,
+        description="The type of sampler to use. Options are TRTLLMSampler or TorchSampler. Defaults to TorchSampler.",
+    )
+
    # NOTE: we do not support copy_on_partial_reuse in AutoDeploy yet
    # see https://github.com/NVIDIA/TensorRT-LLM/issues/7142
    kv_cache_config: KvCacheConfig = Field(
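
For context, here is a minimal, self-contained sketch (not the repository's actual class) of how a Union[str, SamplerType] pydantic field like the one added above behaves: either the enum member or its string name can be supplied, and the default stays TorchSampler. The DemoConfig class and the local SamplerType stand-in below are illustrative assumptions, not code from this commit.

# Illustrative sketch only: a stand-in enum and config class mirroring the new
# `sampler_type` field; the real definitions live in tensorrt_llm.llmapi.llm_args
# and tensorrt_llm._torch.auto_deploy.llm_args.
from enum import Enum
from typing import Union

from pydantic import BaseModel, Field


class SamplerType(str, Enum):  # stand-in for the real SamplerType enum
    TorchSampler = "TorchSampler"
    TRTLLMSampler = "TRTLLMSampler"


class DemoConfig(BaseModel):
    sampler_type: Union[str, SamplerType] = Field(
        default=SamplerType.TorchSampler,
        description="The type of sampler to use.",
    )


print(DemoConfig().sampler_type)  # defaults to SamplerType.TorchSampler
print(DemoConfig(sampler_type="TRTLLMSampler").sampler_type)  # string form is accepted
print(DemoConfig(sampler_type=SamplerType.TRTLLMSampler).sampler_type)  # enum form works too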

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 70 additions & 11 deletions
@@ -20,7 +20,11 @@
 from torch._prims_common import DeviceLikeType

 from tensorrt_llm._torch.attention_backend.interface import AttentionRuntimeFeatures
-from tensorrt_llm._torch.pyexecutor._util import _create_kv_cache_manager, get_kv_cache_manager_cls
+from tensorrt_llm._torch.pyexecutor._util import (
+    _create_kv_cache_manager,
+    get_decoding_mode,
+    get_kv_cache_manager_cls,
+)
 from tensorrt_llm._torch.pyexecutor.guided_decoder import GuidedDecoder
 from tensorrt_llm._torch.pyexecutor.llm_request import get_draft_token_length
 from tensorrt_llm._torch.pyexecutor.py_executor_creator import get_guided_decoding_config
@@ -30,6 +34,7 @@
 from tensorrt_llm.llmapi.llm_args import (
     ContextChunkingPolicy,
     LoadFormat,
+    SamplerType,
     SpeculativeConfig,
     TorchLlmArgs,
 )
@@ -42,7 +47,7 @@
 from ...pyexecutor.model_engine import ModelEngine, PyTorchModelEngine
 from ...pyexecutor.py_executor import PyExecutor
 from ...pyexecutor.resource_manager import KVCacheManager, ResourceManager, ResourceManagerType
-from ...pyexecutor.sampler import TorchSampler
+from ...pyexecutor.sampler import TorchSampler, TRTLLMSampler
 from ...pyexecutor.scheduler import (
     BindCapacityScheduler,
     BindMicroBatchScheduler,
@@ -283,9 +288,9 @@ def __init__(
        self.llm_args.batch_wait_timeout_iters = 0
        self.llm_args.batch_wait_max_tokens_ratio = 0.0
        self.llm_args.max_num_tokens = seq_info.max_num_tokens
+        self.llm_args.max_seq_len = seq_info.max_seq_len
        self.iter_counter = 0
        self.iter_states = {}
-        self.llm_args.max_seq_len = seq_info.max_seq_len

        # NOTE (lucaslie): not a declared base member in the base class; required by PyExecutor...
        self.max_beam_width = max_beam_width
@@ -487,6 +492,9 @@ def _compute_logits(self) -> List[torch.Tensor]:
        # run the model
        logits: torch.Tensor = self.model(**self.cache_seq_interface.named_args)[0]

+        # TRTLLMSampler expects float32 logits. PyTorchModelEngine always casts to float32 regardless.
+        logits = logits.float()
+
        # return a list of tensors
        return self.cache_seq_interface.info.unnest_sequences(logits)

@@ -574,6 +582,59 @@ def create_draft_model_engine_maybe(
    return draft_model_engine


+class TRTLLMSamplerModelConfig:
+    def __init__(self, vocab_size_padded: int):
+        self.config = SimpleNamespace()
+        self.config.vocab_size = vocab_size_padded
+
+        # Initialized to dummy values as they are not used in the C++ code underlying TRTLLMSampler.
+        self.config.num_hidden_layers = 42
+        self.config.hidden_size = 42
+        self.config.num_attention_heads = 42
+
+
+def instantiate_sampler(
+    ad_config: LlmArgs,
+    max_num_sequences: int,
+    max_draft_len: int,
+    max_total_draft_tokens: int,
+    dist_mapping: Mapping,
+    engine: ADEngine,
+):
+    if ad_config.sampler_type == SamplerType.TorchSampler:
+        # search sampler with speculative decoding
+        sampler_args = TorchSampler.Args(
+            max_seq_len=ad_config.max_seq_len,
+            max_draft_len=max_draft_len,
+            max_total_draft_tokens=max_total_draft_tokens,
+            max_num_sequences=max_num_sequences,
+            max_beam_width=ad_config.max_beam_width,
+            disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
+        )
+        sampler = TorchSampler(sampler_args)
+
+    elif ad_config.sampler_type == SamplerType.TRTLLMSampler:
+        vocab_size_padded: int = engine.cache_seq_interface.info.vocab_size_padded
+        sampler_model_config = TRTLLMSamplerModelConfig(vocab_size_padded)
+        decoding_mode = get_decoding_mode(ad_config.decoding_config, ad_config.max_beam_width)
+        sampler = TRTLLMSampler(
+            model=sampler_model_config,
+            model_dtype=torch.bfloat16,  # hardcoded as bfloat16; does not seem necessary in C++ code.
+            mapping=dist_mapping,
+            decoding_mode=decoding_mode,
+            disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
+            max_seq_len=ad_config.max_seq_len,
+            max_batch_size=ad_config.max_batch_size,
+            max_beam_width=ad_config.max_beam_width,
+            decoding_config=ad_config.decoding_config,
+            kv_cache_config=ad_config.kv_cache_config,
+        )
+    else:
+        raise ValueError(f"Sampler type {ad_config.sampler_type} is not supported.")
+
+    return sampler
+
+
def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[TokenizerBase] = None):
    """Create an AutoDeploy executor from the given configuration and tokenizer.

    The tokenizer is required for guided decoding.
@@ -695,23 +756,21 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[TokenizerBase] = None):
    )
    scheduler = SimpleScheduler(capacitor_scheduler, mb_scheduler)

-    # search sampler with speculative decoding
-    sampler_args = TorchSampler.Args(
-        max_seq_len=ad_config.max_seq_len,
+    vocab_size_padded = engine.cache_seq_interface.info.vocab_size_padded
+    sampler = instantiate_sampler(
+        ad_config=ad_config,
+        max_num_sequences=max_num_sequences,
        max_draft_len=max_draft_len,
        max_total_draft_tokens=max_total_draft_tokens,
-        max_num_sequences=max_num_sequences,
-        max_beam_width=ad_config.max_beam_width,
-        disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
+        dist_mapping=dist_mapping,
+        engine=engine,
    )
-    sampler = TorchSampler(sampler_args)

    # Guided (structured) decoding.
    guided_decoder = None
    if (
        (guided_decoding_backend := ad_config.guided_decoding_backend) is not None
    ) and dist_mapping.is_last_pp_rank():
-        vocab_size_padded = engine.cache_seq_interface.info.vocab_size_padded
        if vocab_size_padded is None:
            raise RuntimeError(
                "Could not determine the vocabulary size. Required for guided decoding."
Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from _model_test_utils import get_small_model_config
+from build_and_run_ad import ExperimentConfig, main
+
+from tensorrt_llm.llmapi.llm_args import SamplerType
+
+
+def test_ad_trtllm_sampler_smoke():
+    """Test TRTLLMSampler in AutoDeploy smoke test."""
+    # Get small model config
+    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+    experiment_config = get_small_model_config(model_id)
+
+    # Configure for TRTLLMSampler
+    experiment_config["args"]["runtime"] = "trtllm"
+    experiment_config["args"]["world_size"] = 1
+    experiment_config["args"]["sampler_type"] = SamplerType.TRTLLMSampler
+
+    # Setup simple prompt
+    experiment_config["prompt"]["batch_size"] = 1
+    experiment_config["prompt"]["queries"] = {"prompt": "What is the capital of France?"}
+    experiment_config["prompt"]["sp_kwargs"] = {
+        "max_tokens": 10,
+        "temperature": 1.0,
+        "top_k": 1,
+    }
+
+    print(f"Experiment config: {experiment_config}")
+    cfg = ExperimentConfig(**experiment_config)
+
+    print("Running smoke test with TRTLLMSampler...")
+    results = main(cfg)
+
+    # Basic assertion that we got some output
+    prompts_and_outputs = results["prompts_and_outputs"]
+    assert len(prompts_and_outputs) == 1
+    assert len(prompts_and_outputs[0][1]) > 0
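
To run the new smoke test locally, a minimal invocation sketch follows; it assumes pytest is installed and that the test file (whose path is not shown in this excerpt) is on pytest's collection path.

# Minimal sketch: invoke the smoke test by name via pytest's Python API.
# Adjust the -k expression or pass the file path explicitly once you know where
# the test lives in the repository; both details are assumptions here.
import sys

import pytest

if __name__ == "__main__":
    sys.exit(pytest.main(["-q", "-k", "test_ad_trtllm_sampler_smoke"]))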
